diff --git a/src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp b/src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp index 09bece2c7..065b5f127 100644 --- a/src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp +++ b/src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp @@ -35,20 +35,28 @@ namespace storm { riskPerState = risk; } + template + uint64_t BeliefStateManager::getFreshId() { + beliefIdCounter++; + std::cout << "provide " << beliefIdCounter; + return beliefIdCounter; + } + template SparseBeliefState::SparseBeliefState(std::shared_ptr> const& manager, uint64_t state) - : manager(manager), belief() + : manager(manager), belief(), id(0), prevId(0) { + id = manager->getFreshId(); belief[state] = storm::utility::one(); risk = manager->getRisk(state); } template SparseBeliefState::SparseBeliefState(std::shared_ptr> const& manager, std::map const& belief, - std::size_t hash, ValueType const& risk) - : manager(manager), belief(belief), prestoredhash(hash), risk(risk) + std::size_t hash, ValueType const& risk, uint64_t prevId) + : manager(manager), belief(belief), prestoredhash(hash), risk(risk), id(0), prevId(prevId) { - // Intentionally left empty + id = manager->getFreshId(); } template @@ -74,6 +82,7 @@ namespace storm { template std::string SparseBeliefState::toString() const { std::stringstream sstr; + sstr << "id: " << id << "; "; bool first = true; for (auto const& beliefentry : belief) { if (!first) { @@ -83,6 +92,7 @@ namespace storm { } sstr << beliefentry.first << " : " << beliefentry.second; } + sstr << " (from " << prevId << ")"; return sstr.str(); } @@ -137,7 +147,67 @@ namespace storm { boost::hash_combine(newHash, entry.first); risk += entry.second * manager->getRisk(entry.first); } - return SparseBeliefState(manager, newBelief, newHash, risk); + return SparseBeliefState(manager, newBelief, newHash, risk, id); + } + + template + void SparseBeliefState::update(uint32_t newObservation, std::unordered_set>& previousBeliefs) const { + updateHelper({{}}, {storm::utility::zero()}, belief.begin(), newObservation, previousBeliefs); + } + + template + void SparseBeliefState::updateHelper(std::vector> const& partialBeliefs, std::vector const& sums, typename std::map::const_iterator nextStateIt, uint32_t newObservation, std::unordered_set>& previousBeliefs) const { + if(nextStateIt == belief.end()) { + for (uint64_t i = 0; i < partialBeliefs.size(); ++i) { + auto const& partialBelief = partialBeliefs[i]; + auto const& sum = sums[i]; + if (storm::utility::isZero(sum)) { + continue; + } + std::size_t newHash = 0; + ValueType risk = storm::utility::zero(); + std::map finalBelief; + for (auto &entry : partialBelief) { + assert(!storm::utility::isZero(sum)); + finalBelief[entry.first] = entry.second / sum; + //boost::hash_combine(newHash, std::hash()(entry.second)); + boost::hash_combine(newHash, entry.first); + risk += entry.second / sum * manager->getRisk(entry.first); + } + previousBeliefs.insert(SparseBeliefState(manager, finalBelief, newHash, risk, id)); + } + } else { + uint64_t state = nextStateIt->first; + auto newNextStateIt = nextStateIt; + newNextStateIt++; + std::vector> newPartialBeliefs; + std::vector newSums; + for (uint64_t i = 0; i < partialBeliefs.size(); ++i) { + + for (auto row = manager->getPomdp().getNondeterministicChoiceIndices()[state]; + row < manager->getPomdp().getNondeterministicChoiceIndices()[state + 1]; ++row) { + std::map newPartialBelief = partialBeliefs[i]; + ValueType newSum = sums[i]; + for (auto const &transition : manager->getPomdp().getTransitionMatrix().getRow(row)) { + if (newObservation != manager->getPomdp().getObservation(transition.getColumn())) { + continue; + } + + if (newPartialBelief.count(transition.getColumn()) == 0) { + newPartialBelief[transition.getColumn()] = transition.getValue() * nextStateIt->second; + } else { + newPartialBelief[transition.getColumn()] += transition.getValue() * nextStateIt->second; + } + newSum += transition.getValue() * nextStateIt->second; + + } + newPartialBeliefs.push_back(newPartialBelief); + newSums.push_back(newSum); + } + } + updateHelper(newPartialBeliefs, newSums, newNextStateIt, newObservation, previousBeliefs); + + } } @@ -164,14 +234,11 @@ namespace storm { bool NondeterministicBeliefTracker::track(uint64_t newObservation) { STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Cannot track without a belief (need to reset)."); std::unordered_set newBeliefs; - for (uint64_t action = 0; action < manager->getActionsForObservation(lastObservation); ++action) { - for (auto const& belief : beliefs) { - auto newBelief = belief.update(action, newObservation); - if (newBelief.isValid()) { - newBeliefs.insert(newBelief); - } - } + //for (uint64_t action = 0; action < manager->getActionsForObservation(lastObservation); ++action) { + for (auto const& belief : beliefs) { + belief.update(newObservation, newBeliefs); } + //} beliefs = newBeliefs; lastObservation = newObservation; return !beliefs.empty(); diff --git a/src/storm-pomdp/generator/NondeterministicBeliefTracker.h b/src/storm-pomdp/generator/NondeterministicBeliefTracker.h index 4712f38c6..462a33a93 100644 --- a/src/storm-pomdp/generator/NondeterministicBeliefTracker.h +++ b/src/storm-pomdp/generator/NondeterministicBeliefTracker.h @@ -11,10 +11,12 @@ namespace storm { uint64_t getActionsForObservation(uint32_t observation) const; ValueType getRisk(uint64_t) const; void setRiskPerState(std::vector const& risk); + uint64_t getFreshId(); private: storm::models::sparse::Pomdp const& pomdp; std::vector riskPerState; std::vector numberActionsPerObservation; + uint64_t beliefIdCounter = 0; }; template @@ -26,6 +28,7 @@ namespace storm { public: SparseBeliefState(std::shared_ptr> const& manager, uint64_t state); SparseBeliefState update(uint64_t action, uint32_t observation) const; + void update(uint32_t newObservation, std::unordered_set& previousBeliefs) const; std::size_t hash() const noexcept; ValueType get(uint64_t state) const; ValueType getRisk() const; @@ -34,12 +37,15 @@ namespace storm { friend bool operator==<>(SparseBeliefState const& lhs, SparseBeliefState const& rhs); private: - SparseBeliefState(std::shared_ptr> const& manager, std::map const& belief, std::size_t newHash, ValueType const& risk); + void updateHelper(std::vector> const& partialBeliefs, std::vector const& sums, typename std::map::const_iterator nextStateIt, uint32_t newObservation, std::unordered_set>& previousBeliefs) const; + SparseBeliefState(std::shared_ptr> const& manager, std::map const& belief, std::size_t newHash, ValueType const& risk, uint64_t prevId); std::shared_ptr> manager; std::map belief; // map is ordered for unique hashing. std::size_t prestoredhash = 0; ValueType risk; + uint64_t id; + uint64_t prevId; };