#include "storm-pomdp/generator/NondeterministicBeliefTracker.h"
#include "storm/utility/ConstantsComparator.h"

namespace storm {
namespace generator {

/*!
 * Creates a manager for the given POMDP and tabulates how many actions are available per observation.
 *
 * The table is filled by visiting every state and storing its number of choices under its observation;
 * if several states share an observation, the entry written by the highest-index state is kept.
 * NOTE(review): this presumably relies on all states with the same observation offering the same
 * number of choices — confirm that the model builder guarantees this.
 *
 * @param pomdp The POMDP whose beliefs are managed.
 */
template<typename ValueType>
BeliefStateManager<ValueType>::BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp)
    : pomdp(pomdp) {
    numberActionsPerObservation = std::vector<uint64_t>(pomdp.getNrObservations(), 0);
    for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
        numberActionsPerObservation[pomdp.getObservation(state)] = pomdp.getNumberOfChoices(state);
    }
}

/*!
 * @param observation An observation index (not bounds-checked).
 * @return The number of actions recorded for this observation by the constructor.
 */
template<typename ValueType>
uint64_t BeliefStateManager<ValueType>::getActionsForObservation(uint32_t observation) const {
    return numberActionsPerObservation[observation];
}

/*!
 * @param state A state index.
 * @return The risk configured for this state via setRiskPerState.
 * @throws std::out_of_range if no risk value exists for this index (e.g. setRiskPerState was not called).
 */
template<typename ValueType>
ValueType BeliefStateManager<ValueType>::getRisk(uint64_t state) const {
    return riskPerState.at(state);
}

/*!
 * @return The POMDP this manager was created for.
 */
template<typename ValueType>
storm::models::sparse::Pomdp<ValueType> const& BeliefStateManager<ValueType>::getPomdp() const {
    return pomdp;
}

/*!
 * Replaces the per-state risk values (indexed by state).
 */
template<typename ValueType>
void BeliefStateManager<ValueType>::setRiskPerState(std::vector<ValueType> const& risk) {
    riskPerState = risk;
}

/*!
 * Creates the Dirac belief that puts all probability mass on a single state.
 *
 * @param manager Shared manager providing the POMDP and the per-state risk.
 * @param state The state that receives probability one.
 */
template<typename ValueType>
SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state)
    : manager(manager), belief(), prestoredhash(0) {
    belief[state] = storm::utility::one<ValueType>();
    // Hash the support exactly like update() does. Otherwise this belief and the identical Dirac belief
    // produced by update() would get different hashes, and operator== (which rejects on a hash mismatch
    // first) would consider them different. It also guarantees the hash is never left uninitialised.
    boost::hash_combine(prestoredhash, state);
    risk = manager->getRisk(state);
}

/*!
 * Creates a belief from an already-normalised distribution together with its precomputed hash and risk.
 */
template<typename ValueType>
SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief,
                                                std::size_t hash, ValueType const& risk)
    : manager(manager), belief(belief), prestoredhash(hash), risk(risk) {
    // Intentionally left empty
}

/*!
 * @return The probability of the given state.
 * @throws std::out_of_range if the state is not in the support of this belief.
 */
template<typename ValueType>
ValueType SparseBeliefState<ValueType>::get(uint64_t state) const {
    return belief.at(state);
}

/*!
 * @return The risk of this belief (as stored at construction).
 */
template<typename ValueType>
ValueType SparseBeliefState<ValueType>::getRisk() const {
    return risk;
}

/*!
 * @return The stored hash, which depends only on the support (state indices) of the belief.
 */
template<typename ValueType>
std::size_t SparseBeliefState<ValueType>::hash() const noexcept {
    return prestoredhash;
}

/*!
 * @return true iff the belief has a non-empty support.
 */
template<typename ValueType>
bool SparseBeliefState<ValueType>::isValid() const {
    return !belief.empty();
}

/*!
 * @return A textual rendering "s0 : p0, s1 : p1, ..." in increasing state order.
 */
template<typename ValueType>
std::string SparseBeliefState<ValueType>::toString() const {
    std::stringstream sstr;
    bool first = true;
    for (auto const& beliefentry : belief) {
        if (!first) {
            sstr << ", ";
        } else {
            first = false;
        }
        sstr << beliefentry.first << " : " << beliefentry.second;
    }
    return sstr.str();
}
template bool operator==(SparseBeliefState const& lhs, SparseBeliefState const& rhs) { if (lhs.hash() != rhs.hash()) { return false; } if (lhs.belief.size() != rhs.belief.size()) { return false; } storm::utility::ConstantsComparator cmp(0.00001, true); auto lhsIt = lhs.belief.begin(); auto rhsIt = rhs.belief.begin(); while(lhsIt != lhs.belief.end()) { if (lhsIt->first != rhsIt->first || !cmp.isEqual(lhsIt->second, rhsIt->second)) { return false; } ++lhsIt; ++rhsIt; } return true; //return std::equal(lhs.belief.begin(), lhs.belief.end(), rhs.belief.begin()); } template SparseBeliefState SparseBeliefState::update(uint64_t action, uint32_t observation) const { std::map newBelief; ValueType sum = storm::utility::zero(); for (auto const& beliefentry : belief) { assert(manager->getPomdp().getNumberOfChoices(beliefentry.first) > action); auto row = manager->getPomdp().getNondeterministicChoiceIndices()[beliefentry.first] + action; for (auto const& transition : manager->getPomdp().getTransitionMatrix().getRow(row)) { if (observation != manager->getPomdp().getObservation(transition.getColumn())) { continue; } if (newBelief.count(transition.getColumn()) == 0) { newBelief[transition.getColumn()] = transition.getValue() * beliefentry.second; } else { newBelief[transition.getColumn()] += transition.getValue() * beliefentry.second; } sum += transition.getValue() * beliefentry.second; } } std::size_t newHash = 0; ValueType risk = storm::utility::zero(); for(auto& entry : newBelief) { assert(!storm::utility::isZero(sum)); entry.second /= sum; //boost::hash_combine(newHash, std::hash()(entry.second)); boost::hash_combine(newHash, entry.first); risk += entry.second * manager->getRisk(entry.first); } return SparseBeliefState(manager, newBelief, newHash, risk); } template NondeterministicBeliefTracker::NondeterministicBeliefTracker(storm::models::sparse::Pomdp const& pomdp) : pomdp(pomdp), manager(std::make_shared>(pomdp)), beliefs() { // } template bool 
NondeterministicBeliefTracker::reset(uint32_t observation) { bool hit = false; for (auto state : pomdp.getInitialStates()) { if (observation == pomdp.getObservation(state)) { hit = true; beliefs.emplace(manager, state); } } lastObservation = observation; return hit; } template bool NondeterministicBeliefTracker::track(uint64_t newObservation) { STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Cannot track without a belief (need to reset)."); std::unordered_set newBeliefs; for (uint64_t action = 0; action < manager->getActionsForObservation(lastObservation); ++action) { for (auto const& belief : beliefs) { auto newBelief = belief.update(action, newObservation); if (newBelief.isValid()) { newBeliefs.insert(newBelief); } } } beliefs = newBeliefs; lastObservation = newObservation; return !beliefs.empty(); } template ValueType NondeterministicBeliefTracker::getCurrentRisk(bool max) { STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Risk is only defined for beliefs (run reset() first)."); ValueType result = beliefs.begin()->getRisk(); if (max) { for (auto const& belief : beliefs) { if (belief.getRisk() > result) { result = belief.getRisk(); } } } else { for (auto const& belief : beliefs) { if (belief.getRisk() < result) { result = belief.getRisk(); } } } return result; } template void NondeterministicBeliefTracker::setRisk(std::vector const& risk) { manager->setRiskPerState(risk); } template std::unordered_set const& NondeterministicBeliefTracker::getCurrentBeliefs() const { return beliefs; } template uint32_t NondeterministicBeliefTracker::getCurrentObservation() const { return lastObservation; } template class SparseBeliefState; template bool operator==(SparseBeliefState const&, SparseBeliefState const&); template class NondeterministicBeliefTracker>; } }