diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp index bdf9fab1a..66135c02b 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp @@ -299,7 +299,6 @@ namespace storm { STORM_PRINT("Exploration threshold: " << explorationThreshold << std::endl) } while (!beliefsToBeExpanded.empty()) { - // TODO direct generation of transition matrix? uint64_t currId = beliefsToBeExpanded.front(); beliefsToBeExpanded.pop_front(); bool isTarget = beliefIsTarget[currId]; @@ -637,7 +636,7 @@ namespace storm { // Belief ID -> ActionIndex std::map> chosenActions; - // Belief ID -> Observation -> Probability + // Belief ID -> action -> Observation -> Probability std::map>> observationProbabilities; // current ID -> action -> next ID std::map>> nextBelieves; @@ -1025,7 +1024,7 @@ namespace storm { storm::models::sparse::Pomdp const &pomdp, std::set const &target_observations, uint64_t gridResolution, std::vector> &beliefList, - std::vector> &grid, std::vector &beliefIsKnown, + std::vector> &grid, std::vector &beliefIsTarget, uint64_t nextId) { bool isTarget; uint64_t newId = nextId; @@ -1045,7 +1044,7 @@ namespace storm { << distribution << "]"); beliefList.push_back(belief); grid.push_back(belief); - beliefIsKnown.push_back(isTarget); + beliefIsTarget.push_back(isTarget); ++newId; } else { // Otherwise we have to enumerate all possible distributions with regards to the grid @@ -1073,7 +1072,7 @@ namespace storm { STORM_LOG_TRACE("Add Belief " << std::to_string(newId) << " [(" << std::to_string(observation) << ")," << distribution << "]"); beliefList.push_back(belief); grid.push_back(belief); - beliefIsKnown.push_back(isTarget); + beliefIsTarget.push_back(isTarget); if (helper[statesWithObservation.size() - 1] == storm::utility::convertNumber(gridResolution)) { // If the last entry of helper is the gridResolution, we have enumerated all necessary distributions @@ -1175,7 +1174,16 @@ namespace storm { uint64_t actionIndex) { std::map res; // the id is not important here as we immediately discard the belief (very hacky, I don't like it either) - std::map postProbabilities = getBeliefAfterAction(pomdp, belief, actionIndex, 0).probabilities; + std::map postProbabilities; + for (auto const &probEntry : belief.probabilities) { + uint64_t state = probEntry.first; + auto row = pomdp.getTransitionMatrix().getRow(pomdp.getChoiceIndex(storm::storage::StateActionPair(state, actionIndex))); + for (auto const &entry : row) { + if (entry.getValue() > 0) { + postProbabilities[entry.getColumn()] += belief.probabilities[state] * entry.getValue(); + } + } + } for (auto const &probEntry : postProbabilities) { uint32_t observation = pomdp.getObservation(probEntry.first); if (res.count(observation) == 0) { @@ -1212,7 +1220,6 @@ namespace storm { storm::models::sparse::Pomdp const &pomdp, std::vector> &beliefList, std::vector &beliefIsTarget, std::set const &targetObservations, storm::pomdp::Belief &belief, uint64_t actionIndex, uint32_t observation, uint64_t id) { - storm::utility::Stopwatch distrWatch(true); std::map distributionAfter; for (auto const &probEntry : belief.probabilities) { uint64_t state = probEntry.first; @@ -1223,9 +1230,7 @@ namespace storm { } } } - distrWatch.stop(); // We have to normalize the distribution - storm::utility::Stopwatch normalizationWatch(true); auto sum = storm::utility::zero(); for (auto const &entry : distributionAfter) { sum += entry.second; @@ -1234,19 +1239,12 @@ namespace storm { for (auto const &entry : distributionAfter) { distributionAfter[entry.first] /= sum; } - normalizationWatch.stop(); if (getBeliefIdInVector(beliefList, observation, distributionAfter) != uint64_t(-1)) { - storm::utility::Stopwatch getWatch(true); auto res = getBeliefIdInVector(beliefList, observation, distributionAfter); - getWatch.stop(); - //STORM_PRINT("Distribution: "<< distrWatch.getTimeInNanoseconds() << " / Normalization: " << normalizationWatch.getTimeInNanoseconds() << " / getId: " << getWatch.getTimeInNanoseconds() << std::endl) return res; } else { - storm::utility::Stopwatch pushWatch(true); beliefList.push_back(storm::pomdp::Belief{id, observation, distributionAfter}); beliefIsTarget.push_back(targetObservations.find(observation) != targetObservations.end()); - pushWatch.stop(); - //STORM_PRINT("Distribution: "<< distrWatch.getTimeInNanoseconds() << " / Normalization: " << normalizationWatch.getTimeInNanoseconds() << " / generateBelief: " << pushWatch.getTimeInNanoseconds() << std::endl) return id; } } diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h index 5108d3135..eb3247de1 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h @@ -215,7 +215,7 @@ namespace storm { std::set const &target_observations, uint64_t gridResolution, std::vector> &beliefList, std::vector> &grid, - std::vector &beliefIsKnown, uint64_t nextId); + std::vector &beliefIsTarget, uint64_t nextId); /**