2 changed files with 290 additions and 0 deletions
			
			
		- 
					202src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp
- 
					88src/storm-pomdp/generator/NondeterministicBeliefTracker.h
| @ -0,0 +1,202 @@ | |||
| 
 | |||
| #include "storm-pomdp/generator/NondeterministicBeliefTracker.h"
 | |||
| 
 | |||
| namespace storm { | |||
|     namespace generator { | |||
| 
 | |||
|         template<typename ValueType> | |||
|         BeliefStateManager<ValueType>::BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp) | |||
|         : pomdp(pomdp) | |||
|         { | |||
|             numberActionsPerObservation = std::vector<uint64_t>(pomdp.getNrObservations(), 0); | |||
|             for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { | |||
|                 numberActionsPerObservation[pomdp.getObservation(state)] = pomdp.getNumberOfChoices(state); | |||
|             } | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         uint64_t BeliefStateManager<ValueType>::getActionsForObservation(uint32_t observation) const { | |||
|             return numberActionsPerObservation[observation]; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         ValueType BeliefStateManager<ValueType>::getRisk(uint64_t state) const { | |||
|             return riskPerState.at(state); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         storm::models::sparse::Pomdp<ValueType> const& BeliefStateManager<ValueType>::getPomdp() const { | |||
|             return pomdp; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         void BeliefStateManager<ValueType>::setRiskPerState(std::vector<ValueType> const& risk) { | |||
|             riskPerState = risk; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state) | |||
|         : manager(manager), belief() | |||
|         { | |||
|             belief[state] = storm::utility::one<ValueType>(); | |||
|             risk = manager->getRisk(state); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief, | |||
|                                                         std::size_t hash, ValueType const& risk) | |||
|         : manager(manager), belief(belief), prestoredhash(hash), risk(risk) | |||
|         { | |||
|             // Intentionally left empty
 | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         ValueType SparseBeliefState<ValueType>::get(uint64_t state) const { | |||
|             return belief.at(state); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         ValueType SparseBeliefState<ValueType>::getRisk() const { | |||
|             return risk; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         std::size_t SparseBeliefState<ValueType>::hash() const noexcept { | |||
|             return prestoredhash; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         bool SparseBeliefState<ValueType>::isValid() const { | |||
|             return !belief.empty(); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         std::string SparseBeliefState<ValueType>::toString() const { | |||
|             std::stringstream sstr; | |||
|             bool first = true; | |||
|             for (auto const& beliefentry : belief) { | |||
|                 if (!first) { | |||
|                     sstr << ", "; | |||
|                 } else { | |||
|                     first = false; | |||
|                 } | |||
|                 sstr << beliefentry.first << " : " << beliefentry.second; | |||
|             } | |||
|             return sstr.str(); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs) { | |||
|             return lhs.hash() == rhs.hash() && lhs.belief == rhs.belief; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType> | |||
|         SparseBeliefState<ValueType> SparseBeliefState<ValueType>::update(uint64_t action, uint32_t observation) const { | |||
|             std::map<uint64_t, ValueType> newBelief; | |||
|             ValueType sum = storm::utility::zero<ValueType>(); | |||
|             for (auto const& beliefentry : belief) { | |||
|                 assert(manager->getPomdp().getNumberOfChoices(beliefentry.first) > action); | |||
|                 auto row = manager->getPomdp().getNondeterministicChoiceIndices()[beliefentry.first] + action; | |||
|                 for (auto const& transition : manager->getPomdp().getTransitionMatrix().getRow(row)) { | |||
|                     if (observation != manager->getPomdp().getObservation(transition.getColumn())) { | |||
|                         continue; | |||
|                     } | |||
| 
 | |||
|                     if (newBelief.count(transition.getColumn()) == 0) { | |||
|                         newBelief[transition.getColumn()] = transition.getValue() * beliefentry.second; | |||
|                     } else { | |||
|                         newBelief[transition.getColumn()] += transition.getValue() * beliefentry.second; | |||
|                     } | |||
|                     sum += transition.getValue() * beliefentry.second; | |||
|                 } | |||
|             } | |||
|             std::size_t newHash = 0; | |||
|             ValueType risk = storm::utility::zero<ValueType>(); | |||
|             for(auto& entry : newBelief) { | |||
|                 assert(!storm::utility::isZero(sum)); | |||
|                 entry.second /= sum; | |||
|                 boost::hash_combine(newHash, std::hash<ValueType>()(entry.second)); | |||
|                 boost::hash_combine(newHash, entry.first); | |||
|                 risk += entry.second * manager->getRisk(entry.first); | |||
|             } | |||
|             return SparseBeliefState<ValueType>(manager, newBelief, newHash, risk); | |||
|         } | |||
| 
 | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         NondeterministicBeliefTracker<ValueType, BeliefState>::NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp) : | |||
|         pomdp(pomdp), manager(std::make_shared<BeliefStateManager<ValueType>>(pomdp)), beliefs() { | |||
|             //
 | |||
|         } | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         bool NondeterministicBeliefTracker<ValueType, BeliefState>::reset(uint32_t observation) { | |||
|             bool hit = false; | |||
|             for (auto state : pomdp.getInitialStates()) { | |||
|                 if (observation == pomdp.getObservation(state)) { | |||
|                     hit = true; | |||
|                     beliefs.emplace(manager, state); | |||
|                 } | |||
|             } | |||
|             lastObservation = observation; | |||
|             return hit; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         bool NondeterministicBeliefTracker<ValueType, BeliefState>::track(uint64_t newObservation) { | |||
|             STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Cannot track without a belief (need to reset)."); | |||
|             std::unordered_set<BeliefState> newBeliefs; | |||
|             for (uint64_t action = 0; action < manager->getActionsForObservation(lastObservation); ++action) { | |||
|                 for (auto const& belief : beliefs) { | |||
|                     auto newBelief = belief.update(action, newObservation); | |||
|                     if (newBelief.isValid()) { | |||
|                         newBeliefs.insert(newBelief); | |||
|                     } | |||
|                 } | |||
|             } | |||
|             beliefs = newBeliefs; | |||
|             lastObservation = newObservation; | |||
|             return !beliefs.empty(); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         ValueType NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentRisk(bool max) { | |||
|             STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Risk is only defined for beliefs (run reset() first)."); | |||
|             ValueType result = beliefs.begin()->getRisk(); | |||
|             if (max) { | |||
|                 for (auto const& belief : beliefs) { | |||
|                     if (belief.getRisk() > result) { | |||
|                         result = belief.getRisk(); | |||
|                     } | |||
|                 } | |||
|             } else { | |||
|                 for (auto const& belief : beliefs) { | |||
|                     if (belief.getRisk() < result) { | |||
|                         result = belief.getRisk(); | |||
|                     } | |||
|                 } | |||
|             } | |||
|             return result; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         void NondeterministicBeliefTracker<ValueType, BeliefState>::setRisk(std::vector<ValueType> const& risk) { | |||
|             manager->setRiskPerState(risk); | |||
|         } | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         std::unordered_set<BeliefState> const& NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentBeliefs() const { | |||
|             return beliefs; | |||
|         } | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         uint32_t NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentObservation() const { | |||
|             return lastObservation; | |||
|         } | |||
| 
 | |||
|         template class SparseBeliefState<double>; | |||
|         template bool operator==(SparseBeliefState<double> const&, SparseBeliefState<double> const&); | |||
|         template class NondeterministicBeliefTracker<double, SparseBeliefState<double>>; | |||
| 
 | |||
|     } | |||
| } | |||
| @ -0,0 +1,88 @@ | |||
| #pragma once | |||
| #include "storm/models/sparse/Pomdp.h" | |||
| 
 | |||
| namespace storm { | |||
|     namespace generator { | |||
|         template<typename ValueType> | |||
|         class BeliefStateManager { | |||
|         public: | |||
|             BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp); | |||
|             storm::models::sparse::Pomdp<ValueType> const& getPomdp() const; | |||
|             uint64_t getActionsForObservation(uint32_t observation) const; | |||
|             ValueType getRisk(uint64_t) const; | |||
|             void setRiskPerState(std::vector<ValueType> const& risk); | |||
|         private: | |||
|             storm::models::sparse::Pomdp<ValueType> const& pomdp; | |||
|             std::vector<ValueType> riskPerState; | |||
|             std::vector<uint64_t> numberActionsPerObservation; | |||
|         }; | |||
| 
 | |||
|         template<typename ValueType> | |||
|         class SparseBeliefState; | |||
|         template<typename ValueType> | |||
|         bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); | |||
|         template<typename ValueType> | |||
|         class SparseBeliefState { | |||
|         public: | |||
|             SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); | |||
|             SparseBeliefState update(uint64_t action, uint32_t  observation) const; | |||
|             std::size_t hash() const noexcept; | |||
|             ValueType get(uint64_t state) const; | |||
|             ValueType getRisk() const; | |||
|             std::string toString() const; | |||
|             bool isValid() const; | |||
| 
 | |||
|             friend bool operator==<>(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); | |||
|         private: | |||
|             SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief, std::size_t newHash,  ValueType const& risk); | |||
|             std::shared_ptr<BeliefStateManager<ValueType>> manager; | |||
| 
 | |||
|             std::map<uint64_t, ValueType> belief; // map is ordered for unique hashing. | |||
|             std::size_t prestoredhash = 0; | |||
|             ValueType risk; | |||
| 
 | |||
|         }; | |||
| 
 | |||
| 
 | |||
|         template<typename ValueType> | |||
|         class ObservationDenseBeliefState { | |||
|         public: | |||
|             ObservationDenseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); | |||
|             ObservationDenseBeliefState update(uint64_t action, uint32_t observation) const; | |||
|         private: | |||
|             std::shared_ptr<BeliefStateManager<ValueType>> manager; | |||
|             std::unordered_map<uint64_t, ValueType> belief; | |||
| 
 | |||
|             void normalize(); | |||
|         }; | |||
| 
 | |||
|         template<typename ValueType, typename BeliefState> | |||
|         class NondeterministicBeliefTracker { | |||
|         public: | |||
|             NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp); | |||
|             bool reset(uint32_t observation); | |||
|             bool track(uint64_t newObservation); | |||
|             std::unordered_set<BeliefState> const& getCurrentBeliefs() const; | |||
|             uint32_t getCurrentObservation() const; | |||
|             ValueType getCurrentRisk(bool max=true); | |||
|             void setRisk(std::vector<ValueType> const& risk); | |||
| 
 | |||
|         private: | |||
| 
 | |||
|             storm::models::sparse::Pomdp<ValueType> const& pomdp; | |||
|             std::shared_ptr<BeliefStateManager<ValueType>> manager; | |||
|             std::unordered_set<BeliefState> beliefs; | |||
|             uint32_t lastObservation; | |||
|         }; | |||
|     } | |||
| } | |||
| 
 | |||
| // | |||
| namespace std { | |||
|     template<typename T> | |||
|     struct hash<storm::generator::SparseBeliefState<T>> { | |||
|         std::size_t operator()(storm::generator::SparseBeliefState<T> const& s) const noexcept { | |||
|             return s.hash(); | |||
|         } | |||
|     }; | |||
| } | |||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue