Sebastian Junges
4 years ago
2 changed files with 290 additions and 0 deletions
-
202src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp
-
88src/storm-pomdp/generator/NondeterministicBeliefTracker.h
@ -0,0 +1,202 @@ |
|||||
|
|
||||
|
#include "storm-pomdp/generator/NondeterministicBeliefTracker.h"
|
||||
|
|
||||
|
namespace storm { |
||||
|
namespace generator { |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
BeliefStateManager<ValueType>::BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp) |
||||
|
: pomdp(pomdp) |
||||
|
{ |
||||
|
numberActionsPerObservation = std::vector<uint64_t>(pomdp.getNrObservations(), 0); |
||||
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
||||
|
numberActionsPerObservation[pomdp.getObservation(state)] = pomdp.getNumberOfChoices(state); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
uint64_t BeliefStateManager<ValueType>::getActionsForObservation(uint32_t observation) const { |
||||
|
return numberActionsPerObservation[observation]; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
ValueType BeliefStateManager<ValueType>::getRisk(uint64_t state) const { |
||||
|
return riskPerState.at(state); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
storm::models::sparse::Pomdp<ValueType> const& BeliefStateManager<ValueType>::getPomdp() const { |
||||
|
return pomdp; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
void BeliefStateManager<ValueType>::setRiskPerState(std::vector<ValueType> const& risk) { |
||||
|
riskPerState = risk; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state) |
||||
|
: manager(manager), belief() |
||||
|
{ |
||||
|
belief[state] = storm::utility::one<ValueType>(); |
||||
|
risk = manager->getRisk(state); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief, |
||||
|
std::size_t hash, ValueType const& risk) |
||||
|
: manager(manager), belief(belief), prestoredhash(hash), risk(risk) |
||||
|
{ |
||||
|
// Intentionally left empty
|
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
ValueType SparseBeliefState<ValueType>::get(uint64_t state) const { |
||||
|
return belief.at(state); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
ValueType SparseBeliefState<ValueType>::getRisk() const { |
||||
|
return risk; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
std::size_t SparseBeliefState<ValueType>::hash() const noexcept { |
||||
|
return prestoredhash; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
bool SparseBeliefState<ValueType>::isValid() const { |
||||
|
return !belief.empty(); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
std::string SparseBeliefState<ValueType>::toString() const { |
||||
|
std::stringstream sstr; |
||||
|
bool first = true; |
||||
|
for (auto const& beliefentry : belief) { |
||||
|
if (!first) { |
||||
|
sstr << ", "; |
||||
|
} else { |
||||
|
first = false; |
||||
|
} |
||||
|
sstr << beliefentry.first << " : " << beliefentry.second; |
||||
|
} |
||||
|
return sstr.str(); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs) { |
||||
|
return lhs.hash() == rhs.hash() && lhs.belief == rhs.belief; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
SparseBeliefState<ValueType> SparseBeliefState<ValueType>::update(uint64_t action, uint32_t observation) const { |
||||
|
std::map<uint64_t, ValueType> newBelief; |
||||
|
ValueType sum = storm::utility::zero<ValueType>(); |
||||
|
for (auto const& beliefentry : belief) { |
||||
|
assert(manager->getPomdp().getNumberOfChoices(beliefentry.first) > action); |
||||
|
auto row = manager->getPomdp().getNondeterministicChoiceIndices()[beliefentry.first] + action; |
||||
|
for (auto const& transition : manager->getPomdp().getTransitionMatrix().getRow(row)) { |
||||
|
if (observation != manager->getPomdp().getObservation(transition.getColumn())) { |
||||
|
continue; |
||||
|
} |
||||
|
|
||||
|
if (newBelief.count(transition.getColumn()) == 0) { |
||||
|
newBelief[transition.getColumn()] = transition.getValue() * beliefentry.second; |
||||
|
} else { |
||||
|
newBelief[transition.getColumn()] += transition.getValue() * beliefentry.second; |
||||
|
} |
||||
|
sum += transition.getValue() * beliefentry.second; |
||||
|
} |
||||
|
} |
||||
|
std::size_t newHash = 0; |
||||
|
ValueType risk = storm::utility::zero<ValueType>(); |
||||
|
for(auto& entry : newBelief) { |
||||
|
assert(!storm::utility::isZero(sum)); |
||||
|
entry.second /= sum; |
||||
|
boost::hash_combine(newHash, std::hash<ValueType>()(entry.second)); |
||||
|
boost::hash_combine(newHash, entry.first); |
||||
|
risk += entry.second * manager->getRisk(entry.first); |
||||
|
} |
||||
|
return SparseBeliefState<ValueType>(manager, newBelief, newHash, risk); |
||||
|
} |
||||
|
|
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
NondeterministicBeliefTracker<ValueType, BeliefState>::NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp) : |
||||
|
pomdp(pomdp), manager(std::make_shared<BeliefStateManager<ValueType>>(pomdp)), beliefs() { |
||||
|
//
|
||||
|
} |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
bool NondeterministicBeliefTracker<ValueType, BeliefState>::reset(uint32_t observation) { |
||||
|
bool hit = false; |
||||
|
for (auto state : pomdp.getInitialStates()) { |
||||
|
if (observation == pomdp.getObservation(state)) { |
||||
|
hit = true; |
||||
|
beliefs.emplace(manager, state); |
||||
|
} |
||||
|
} |
||||
|
lastObservation = observation; |
||||
|
return hit; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
bool NondeterministicBeliefTracker<ValueType, BeliefState>::track(uint64_t newObservation) { |
||||
|
STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Cannot track without a belief (need to reset)."); |
||||
|
std::unordered_set<BeliefState> newBeliefs; |
||||
|
for (uint64_t action = 0; action < manager->getActionsForObservation(lastObservation); ++action) { |
||||
|
for (auto const& belief : beliefs) { |
||||
|
auto newBelief = belief.update(action, newObservation); |
||||
|
if (newBelief.isValid()) { |
||||
|
newBeliefs.insert(newBelief); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
beliefs = newBeliefs; |
||||
|
lastObservation = newObservation; |
||||
|
return !beliefs.empty(); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
ValueType NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentRisk(bool max) { |
||||
|
STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Risk is only defined for beliefs (run reset() first)."); |
||||
|
ValueType result = beliefs.begin()->getRisk(); |
||||
|
if (max) { |
||||
|
for (auto const& belief : beliefs) { |
||||
|
if (belief.getRisk() > result) { |
||||
|
result = belief.getRisk(); |
||||
|
} |
||||
|
} |
||||
|
} else { |
||||
|
for (auto const& belief : beliefs) { |
||||
|
if (belief.getRisk() < result) { |
||||
|
result = belief.getRisk(); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
return result; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
void NondeterministicBeliefTracker<ValueType, BeliefState>::setRisk(std::vector<ValueType> const& risk) { |
||||
|
manager->setRiskPerState(risk); |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
std::unordered_set<BeliefState> const& NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentBeliefs() const { |
||||
|
return beliefs; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
uint32_t NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentObservation() const { |
||||
|
return lastObservation; |
||||
|
} |
||||
|
|
||||
|
template class SparseBeliefState<double>; |
||||
|
template bool operator==(SparseBeliefState<double> const&, SparseBeliefState<double> const&); |
||||
|
template class NondeterministicBeliefTracker<double, SparseBeliefState<double>>; |
||||
|
|
||||
|
} |
||||
|
} |
@ -0,0 +1,88 @@ |
|||||
|
#pragma once |
||||
|
#include "storm/models/sparse/Pomdp.h" |
||||
|
|
||||
|
namespace storm { |
||||
|
namespace generator { |
||||
|
template<typename ValueType> |
||||
|
class BeliefStateManager { |
||||
|
public: |
||||
|
BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp); |
||||
|
storm::models::sparse::Pomdp<ValueType> const& getPomdp() const; |
||||
|
uint64_t getActionsForObservation(uint32_t observation) const; |
||||
|
ValueType getRisk(uint64_t) const; |
||||
|
void setRiskPerState(std::vector<ValueType> const& risk); |
||||
|
private: |
||||
|
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
||||
|
std::vector<ValueType> riskPerState; |
||||
|
std::vector<uint64_t> numberActionsPerObservation; |
||||
|
}; |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
class SparseBeliefState; |
||||
|
template<typename ValueType> |
||||
|
bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); |
||||
|
template<typename ValueType> |
||||
|
class SparseBeliefState { |
||||
|
public: |
||||
|
SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); |
||||
|
SparseBeliefState update(uint64_t action, uint32_t observation) const; |
||||
|
std::size_t hash() const noexcept; |
||||
|
ValueType get(uint64_t state) const; |
||||
|
ValueType getRisk() const; |
||||
|
std::string toString() const; |
||||
|
bool isValid() const; |
||||
|
|
||||
|
friend bool operator==<>(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); |
||||
|
private: |
||||
|
SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief, std::size_t newHash, ValueType const& risk); |
||||
|
std::shared_ptr<BeliefStateManager<ValueType>> manager; |
||||
|
|
||||
|
std::map<uint64_t, ValueType> belief; // map is ordered for unique hashing. |
||||
|
std::size_t prestoredhash = 0; |
||||
|
ValueType risk; |
||||
|
|
||||
|
}; |
||||
|
|
||||
|
|
||||
|
template<typename ValueType> |
||||
|
class ObservationDenseBeliefState { |
||||
|
public: |
||||
|
ObservationDenseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); |
||||
|
ObservationDenseBeliefState update(uint64_t action, uint32_t observation) const; |
||||
|
private: |
||||
|
std::shared_ptr<BeliefStateManager<ValueType>> manager; |
||||
|
std::unordered_map<uint64_t, ValueType> belief; |
||||
|
|
||||
|
void normalize(); |
||||
|
}; |
||||
|
|
||||
|
template<typename ValueType, typename BeliefState> |
||||
|
class NondeterministicBeliefTracker { |
||||
|
public: |
||||
|
NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp); |
||||
|
bool reset(uint32_t observation); |
||||
|
bool track(uint64_t newObservation); |
||||
|
std::unordered_set<BeliefState> const& getCurrentBeliefs() const; |
||||
|
uint32_t getCurrentObservation() const; |
||||
|
ValueType getCurrentRisk(bool max=true); |
||||
|
void setRisk(std::vector<ValueType> const& risk); |
||||
|
|
||||
|
private: |
||||
|
|
||||
|
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
||||
|
std::shared_ptr<BeliefStateManager<ValueType>> manager; |
||||
|
std::unordered_set<BeliefState> beliefs; |
||||
|
uint32_t lastObservation; |
||||
|
}; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
// |
||||
|
namespace std { |
||||
|
template<typename T> |
||||
|
struct hash<storm::generator::SparseBeliefState<T>> { |
||||
|
std::size_t operator()(storm::generator::SparseBeliefState<T> const& s) const noexcept { |
||||
|
return s.hash(); |
||||
|
} |
||||
|
}; |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue