|
@ -1,9 +1,12 @@ |
|
|
#pragma once |
|
|
#pragma once |
|
|
#include "storm/models/sparse/Pomdp.h" |
|
|
#include "storm/models/sparse/Pomdp.h" |
|
|
#include "storm/adapters/EigenAdapter.h" |
|
|
|
|
|
|
|
|
|
|
|
namespace storm { |
|
|
namespace storm { |
|
|
namespace generator { |
|
|
namespace generator { |
|
|
|
|
|
/** |
|
|
|
|
|
* This class keeps track of common information of a set of beliefs. |
|
|
|
|
|
* It also keeps a reference to the POMDP. The manager is referenced by all beliefs. |
|
|
|
|
|
*/ |
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
class BeliefStateManager { |
|
|
class BeliefStateManager { |
|
|
public: |
|
|
public: |
|
@ -28,22 +31,41 @@ namespace storm { |
|
|
std::vector<std::vector<uint64_t>> statePerObservationAndOffset; |
|
|
std::vector<std::vector<uint64_t>> statePerObservationAndOffset; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
class SparseBeliefState; |
|
|
class SparseBeliefState; |
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); |
|
|
bool operator==(SparseBeliefState<ValueType> const& lhs, SparseBeliefState<ValueType> const& rhs); |
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
|
* SparseBeliefState stores beliefs in a sparse format. |
|
|
|
|
|
*/ |
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
class SparseBeliefState { |
|
|
class SparseBeliefState { |
|
|
public: |
|
|
public: |
|
|
SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); |
|
|
SparseBeliefState(std::shared_ptr<BeliefStateManager<ValueType>> const& manager, uint64_t state); |
|
|
SparseBeliefState update(uint64_t action, uint32_t observation) const; |
|
|
|
|
|
|
|
|
/** |
|
|
|
|
|
* Update the belief using the new observation |
|
|
|
|
|
* @param newObservation |
|
|
|
|
|
* @param previousBeliefs put the new belief in this set |
|
|
|
|
|
*/ |
|
|
void update(uint32_t newObservation, std::unordered_set<SparseBeliefState>& previousBeliefs) const; |
|
|
void update(uint32_t newObservation, std::unordered_set<SparseBeliefState>& previousBeliefs) const; |
|
|
std::size_t hash() const noexcept; |
|
|
std::size_t hash() const noexcept; |
|
|
|
|
|
/** |
|
|
|
|
|
* Get the estimate to be in the given state |
|
|
|
|
|
* @param state |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
ValueType get(uint64_t state) const; |
|
|
ValueType get(uint64_t state) const; |
|
|
|
|
|
/** |
|
|
|
|
|
* Get the weighted risk |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
ValueType getRisk() const; |
|
|
ValueType getRisk() const; |
|
|
|
|
|
|
|
|
|
|
|
// Various getters |
|
|
std::string toString() const; |
|
|
std::string toString() const; |
|
|
bool isValid() const; |
|
|
bool isValid() const; |
|
|
Eigen::Matrix<ValueType, Eigen::Dynamic, 1> toEigenVector(storm::storage::BitVector const& support) const; |
|
|
|
|
|
uint64_t getSupportSize() const; |
|
|
uint64_t getSupportSize() const; |
|
|
void setSupport(storm::storage::BitVector&) const; |
|
|
void setSupport(storm::storage::BitVector&) const; |
|
|
std::map<uint64_t, ValueType> const& getBeliefMap() const; |
|
|
std::map<uint64_t, ValueType> const& getBeliefMap() const; |
|
@ -61,7 +83,9 @@ namespace storm { |
|
|
uint64_t prevId; |
|
|
uint64_t prevId; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
|
* ObservationDenseBeliefState stores beliefs in a dense format (per observation). |
|
|
|
|
|
*/ |
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
class ObservationDenseBeliefState; |
|
|
class ObservationDenseBeliefState; |
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
@ -76,7 +100,6 @@ namespace storm { |
|
|
ValueType get(uint64_t state) const; |
|
|
ValueType get(uint64_t state) const; |
|
|
ValueType getRisk() const; |
|
|
ValueType getRisk() const; |
|
|
std::string toString() const; |
|
|
std::string toString() const; |
|
|
Eigen::Matrix<ValueType, Eigen::Dynamic, 1> toEigenVector(storm::storage::BitVector const& support) const; |
|
|
|
|
|
uint64_t getSupportSize() const; |
|
|
uint64_t getSupportSize() const; |
|
|
void setSupport(storm::storage::BitVector&) const; |
|
|
void setSupport(storm::storage::BitVector&) const; |
|
|
friend bool operator==<>(ObservationDenseBeliefState<ValueType> const& lhs, ObservationDenseBeliefState<ValueType> const& rhs); |
|
|
friend bool operator==<>(ObservationDenseBeliefState<ValueType> const& lhs, ObservationDenseBeliefState<ValueType> const& rhs); |
|
@ -93,10 +116,15 @@ namespace storm { |
|
|
uint64_t prevId; |
|
|
uint64_t prevId; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
|
* This tracker implements state estimation for POMDPs. |
|
|
|
|
|
* This corresponds to forward filtering in Junges, Torfah, Seshia. |
|
|
|
|
|
* |
|
|
|
|
|
* @tparam ValueType How are probabilities stored |
|
|
|
|
|
* @tparam BeliefState What format to use for beliefs |
|
|
|
|
|
*/ |
|
|
template<typename ValueType, typename BeliefState> |
|
|
template<typename ValueType, typename BeliefState> |
|
|
class NondeterministicBeliefTracker { |
|
|
class NondeterministicBeliefTracker { |
|
|
|
|
|
|
|
|
public: |
|
|
public: |
|
|
struct Options { |
|
|
struct Options { |
|
|
uint64_t trackTimeOut = 0; |
|
|
uint64_t trackTimeOut = 0; |
|
@ -104,15 +132,59 @@ namespace storm { |
|
|
ValueType wiggle; // tolerance, anything above 0 means that we are incomplete. |
|
|
ValueType wiggle; // tolerance, anything above 0 means that we are incomplete. |
|
|
}; |
|
|
}; |
|
|
NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp, typename NondeterministicBeliefTracker<ValueType, BeliefState>::Options options = Options()); |
|
|
NondeterministicBeliefTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp, typename NondeterministicBeliefTracker<ValueType, BeliefState>::Options options = Options()); |
|
|
|
|
|
/** |
|
|
|
|
|
* Start with a new trace. |
|
|
|
|
|
* @param observation The initial observation to start with. |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
bool reset(uint32_t observation); |
|
|
bool reset(uint32_t observation); |
|
|
|
|
|
/** |
|
|
|
|
|
* Extend the observed trace with the new observation |
|
|
|
|
|
* @param newObservation |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
bool track(uint64_t newObservation); |
|
|
bool track(uint64_t newObservation); |
|
|
|
|
|
/** |
|
|
|
|
|
* Provides access to the current beliefs. |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
std::unordered_set<BeliefState> const& getCurrentBeliefs() const; |
|
|
std::unordered_set<BeliefState> const& getCurrentBeliefs() const; |
|
|
|
|
|
/** |
|
|
|
|
|
* What was the last obervation that we made? |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
uint32_t getCurrentObservation() const; |
|
|
uint32_t getCurrentObservation() const; |
|
|
|
|
|
/** |
|
|
|
|
|
* How many beliefs are we currently tracking? |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
uint64_t getNumberOfBeliefs() const; |
|
|
uint64_t getNumberOfBeliefs() const; |
|
|
|
|
|
/** |
|
|
|
|
|
* What is the (worst-case/best-case) risk over all beliefs |
|
|
|
|
|
* @param max Should we take the max or the min? |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
ValueType getCurrentRisk(bool max=true); |
|
|
ValueType getCurrentRisk(bool max=true); |
|
|
|
|
|
/** |
|
|
|
|
|
* Sets the state-risk to use for all beliefs. |
|
|
|
|
|
* @param risk |
|
|
|
|
|
*/ |
|
|
void setRisk(std::vector<ValueType> const& risk); |
|
|
void setRisk(std::vector<ValueType> const& risk); |
|
|
|
|
|
/** |
|
|
|
|
|
* What is the dimension of the current set of beliefs, i.e., |
|
|
|
|
|
* what is the number of states we could possibly be in? |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
uint64_t getCurrentDimension() const; |
|
|
uint64_t getCurrentDimension() const; |
|
|
|
|
|
/** |
|
|
|
|
|
* Apply reductions to the belief state |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
uint64_t reduce(); |
|
|
uint64_t reduce(); |
|
|
|
|
|
/** |
|
|
|
|
|
* Did we time out during the computation? |
|
|
|
|
|
* @return |
|
|
|
|
|
*/ |
|
|
bool hasTimedOut() const; |
|
|
bool hasTimedOut() const; |
|
|
|
|
|
|
|
|
private: |
|
|
private: |
|
|