Sebastian Junges
4 years ago
2 changed files with 290 additions and 0 deletions
-
202src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp
-
88src/storm-pomdp/generator/NondeterministicBeliefTracker.h
@ -0,0 +1,202 @@ |
|||
|
|||
#include "storm-pomdp/generator/NondeterministicBeliefTracker.h"
|
|||
|
|||
namespace storm { |
|||
namespace generator { |
|||
|
|||
template<typename ValueType> |
|||
BeliefStateManager<ValueType>::BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp) |
|||
: pomdp(pomdp) |
|||
{ |
|||
numberActionsPerObservation = std::vector<uint64_t>(pomdp.getNrObservations(), 0); |
|||
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|||
numberActionsPerObservation[pomdp.getObservation(state)] = pomdp.getNumberOfChoices(state); |
|||
} |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
uint64_t BeliefStateManager<ValueType>::getActionsForObservation(uint32_t observation) const { |
|||
return numberActionsPerObservation[observation]; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
ValueType BeliefStateManager<ValueType>::getRisk(uint64_t state) const { |
|||
return riskPerState.at(state); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
storm::models::sparse::Pomdp<ValueType> const& BeliefStateManager<ValueType>::getPomdp() const { |
|||
return pomdp; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void BeliefStateManager<ValueType>::setRiskPerState(std::vector<ValueType> const& risk) { |
|||
riskPerState = risk; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state) |
|||
: manager(manager), belief() |
|||
{ |
|||
belief[state] = storm::utility::one<ValueType>(); |
|||
risk = manager->getRisk(state); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
SparseBeliefState<ValueType>::SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief, |
|||
std::size_t hash, ValueType const& risk) |
|||
: manager(manager), belief(belief), prestoredhash(hash), risk(risk) |
|||
{ |
|||
// Intentionally left empty
|
|||
} |
|||
|
|||
template<typename ValueType> |
|||
ValueType SparseBeliefState<ValueType>::get(uint64_t state) const { |
|||
return belief.at(state); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
ValueType SparseBeliefState<ValueType>::getRisk() const { |
|||
return risk; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::size_t SparseBeliefState<ValueType>::hash() const noexcept { |
|||
return prestoredhash; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool SparseBeliefState<ValueType>::isValid() const { |
|||
return !belief.empty(); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::string SparseBeliefState<ValueType>::toString() const { |
|||
std::stringstream sstr; |
|||
bool first = true; |
|||
for (auto const& beliefentry : belief) { |
|||
if (!first) { |
|||
sstr << ", "; |
|||
} else { |
|||
first = false; |
|||
} |
|||
sstr << beliefentry.first << " : " << beliefentry.second; |
|||
} |
|||
return sstr.str(); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs) { |
|||
return lhs.hash() == rhs.hash() && lhs.belief == rhs.belief; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
SparseBeliefState<ValueType> SparseBeliefState<ValueType>::update(uint64_t action, uint32_t observation) const { |
|||
std::map<uint64_t, ValueType> newBelief; |
|||
ValueType sum = storm::utility::zero<ValueType>(); |
|||
for (auto const& beliefentry : belief) { |
|||
assert(manager->getPomdp().getNumberOfChoices(beliefentry.first) > action); |
|||
auto row = manager->getPomdp().getNondeterministicChoiceIndices()[beliefentry.first] + action; |
|||
for (auto const& transition : manager->getPomdp().getTransitionMatrix().getRow(row)) { |
|||
if (observation != manager->getPomdp().getObservation(transition.getColumn())) { |
|||
continue; |
|||
} |
|||
|
|||
if (newBelief.count(transition.getColumn()) == 0) { |
|||
newBelief[transition.getColumn()] = transition.getValue() * beliefentry.second; |
|||
} else { |
|||
newBelief[transition.getColumn()] += transition.getValue() * beliefentry.second; |
|||
} |
|||
sum += transition.getValue() * beliefentry.second; |
|||
} |
|||
} |
|||
std::size_t newHash = 0; |
|||
ValueType risk = storm::utility::zero<ValueType>(); |
|||
for(auto& entry : newBelief) { |
|||
assert(!storm::utility::isZero(sum)); |
|||
entry.second /= sum; |
|||
boost::hash_combine(newHash, std::hash<ValueType>()(entry.second)); |
|||
boost::hash_combine(newHash, entry.first); |
|||
risk += entry.second * manager->getRisk(entry.first); |
|||
} |
|||
return SparseBeliefState<ValueType>(manager, newBelief, newHash, risk); |
|||
} |
|||
|
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
NondeterministicBeliefTracker<ValueType, BeliefState>::NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp) : |
|||
pomdp(pomdp), manager(std::make_shared<BeliefStateManager<ValueType>>(pomdp)), beliefs() { |
|||
//
|
|||
} |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
bool NondeterministicBeliefTracker<ValueType, BeliefState>::reset(uint32_t observation) { |
|||
bool hit = false; |
|||
for (auto state : pomdp.getInitialStates()) { |
|||
if (observation == pomdp.getObservation(state)) { |
|||
hit = true; |
|||
beliefs.emplace(manager, state); |
|||
} |
|||
} |
|||
lastObservation = observation; |
|||
return hit; |
|||
} |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
bool NondeterministicBeliefTracker<ValueType, BeliefState>::track(uint64_t newObservation) { |
|||
STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Cannot track without a belief (need to reset)."); |
|||
std::unordered_set<BeliefState> newBeliefs; |
|||
for (uint64_t action = 0; action < manager->getActionsForObservation(lastObservation); ++action) { |
|||
for (auto const& belief : beliefs) { |
|||
auto newBelief = belief.update(action, newObservation); |
|||
if (newBelief.isValid()) { |
|||
newBeliefs.insert(newBelief); |
|||
} |
|||
} |
|||
} |
|||
beliefs = newBeliefs; |
|||
lastObservation = newObservation; |
|||
return !beliefs.empty(); |
|||
} |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
ValueType NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentRisk(bool max) { |
|||
STORM_LOG_THROW(!beliefs.empty(), storm::exceptions::InvalidOperationException, "Risk is only defined for beliefs (run reset() first)."); |
|||
ValueType result = beliefs.begin()->getRisk(); |
|||
if (max) { |
|||
for (auto const& belief : beliefs) { |
|||
if (belief.getRisk() > result) { |
|||
result = belief.getRisk(); |
|||
} |
|||
} |
|||
} else { |
|||
for (auto const& belief : beliefs) { |
|||
if (belief.getRisk() < result) { |
|||
result = belief.getRisk(); |
|||
} |
|||
} |
|||
} |
|||
return result; |
|||
} |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
void NondeterministicBeliefTracker<ValueType, BeliefState>::setRisk(std::vector<ValueType> const& risk) { |
|||
manager->setRiskPerState(risk); |
|||
} |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
std::unordered_set<BeliefState> const& NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentBeliefs() const { |
|||
return beliefs; |
|||
} |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
uint32_t NondeterministicBeliefTracker<ValueType, BeliefState>::getCurrentObservation() const { |
|||
return lastObservation; |
|||
} |
|||
|
|||
template class SparseBeliefState<double>; |
|||
template bool operator==(SparseBeliefState<double> const&, SparseBeliefState<double> const&); |
|||
template class NondeterministicBeliefTracker<double, SparseBeliefState<double>>; |
|||
|
|||
} |
|||
} |
@ -0,0 +1,88 @@ |
|||
#pragma once |
|||
#include "storm/models/sparse/Pomdp.h" |
|||
|
|||
namespace storm { |
|||
namespace generator { |
|||
template<typename ValueType> |
|||
class BeliefStateManager { |
|||
public: |
|||
BeliefStateManager(storm::models::sparse::Pomdp<ValueType> const& pomdp); |
|||
storm::models::sparse::Pomdp<ValueType> const& getPomdp() const; |
|||
uint64_t getActionsForObservation(uint32_t observation) const; |
|||
ValueType getRisk(uint64_t) const; |
|||
void setRiskPerState(std::vector<ValueType> const& risk); |
|||
private: |
|||
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
|||
std::vector<ValueType> riskPerState; |
|||
std::vector<uint64_t> numberActionsPerObservation; |
|||
}; |
|||
|
|||
template<typename ValueType> |
|||
class SparseBeliefState; |
|||
template<typename ValueType> |
|||
bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); |
|||
template<typename ValueType> |
|||
class SparseBeliefState { |
|||
public: |
|||
SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); |
|||
SparseBeliefState update(uint64_t action, uint32_t observation) const; |
|||
std::size_t hash() const noexcept; |
|||
ValueType get(uint64_t state) const; |
|||
ValueType getRisk() const; |
|||
std::string toString() const; |
|||
bool isValid() const; |
|||
|
|||
friend bool operator==<>(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); |
|||
private: |
|||
SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, std::map<uint64_t, ValueType> const& belief, std::size_t newHash, ValueType const& risk); |
|||
std::shared_ptr<BeliefStateManager<ValueType>> manager; |
|||
|
|||
std::map<uint64_t, ValueType> belief; // map is ordered for unique hashing. |
|||
std::size_t prestoredhash = 0; |
|||
ValueType risk; |
|||
|
|||
}; |
|||
|
|||
|
|||
template<typename ValueType> |
|||
class ObservationDenseBeliefState { |
|||
public: |
|||
ObservationDenseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); |
|||
ObservationDenseBeliefState update(uint64_t action, uint32_t observation) const; |
|||
private: |
|||
std::shared_ptr<BeliefStateManager<ValueType>> manager; |
|||
std::unordered_map<uint64_t, ValueType> belief; |
|||
|
|||
void normalize(); |
|||
}; |
|||
|
|||
template<typename ValueType, typename BeliefState> |
|||
class NondeterministicBeliefTracker { |
|||
public: |
|||
NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp); |
|||
bool reset(uint32_t observation); |
|||
bool track(uint64_t newObservation); |
|||
std::unordered_set<BeliefState> const& getCurrentBeliefs() const; |
|||
uint32_t getCurrentObservation() const; |
|||
ValueType getCurrentRisk(bool max=true); |
|||
void setRisk(std::vector<ValueType> const& risk); |
|||
|
|||
private: |
|||
|
|||
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
|||
std::shared_ptr<BeliefStateManager<ValueType>> manager; |
|||
std::unordered_set<BeliefState> beliefs; |
|||
uint32_t lastObservation; |
|||
}; |
|||
} |
|||
} |
|||
|
|||
// |
|||
namespace std { |
|||
template<typename T> |
|||
struct hash<storm::generator::SparseBeliefState<T>> { |
|||
std::size_t operator()(storm::generator::SparseBeliefState<T> const& s) const noexcept { |
|||
return s.hash(); |
|||
} |
|||
}; |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue