Browse Source

BeliefGrid: Adding support for rewards.

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
98bb48d3c5
  1. 33
      src/storm-pomdp/storage/BeliefGrid.h

33
src/storm-pomdp/storage/BeliefGrid.h

@ -24,6 +24,19 @@ namespace storm {
// Intentionally left empty
}
void setRewardModel(boost::optional<std::string> rewardModelName = boost::none) {
if (rewardModelName) {
auto const& rewardModel = pomdp.getRewardModel(rewardModelName.get());
pomdpActionRewardVector = rewardModel.getTotalRewardVector(pomdp.getTransitionMatrix());
} else {
setRewardModel(pomdp.getUniqueRewardModelName());
}
}
void unsetRewardModel() {
pomdpActionRewardVector.clear();
}
struct Triangulation {
std::vector<BeliefId> gridPoints;
std::vector<BeliefValueType> weights;
@ -89,7 +102,11 @@ namespace storm {
BeliefValueType sum = storm::utility::zero<ValueType>();
boost::optional<uint32_t> observation;
for (auto const& entry : belief) {
uintmax_t entryObservation = pomdp.getObservation(entry.first);
if (entry.first >= pomdp.getNumberOfStates()) {
STORM_LOG_ERROR("Belief does refer to non-existing pomdp state " << entry.first << ".");
return false;
}
uint64_t entryObservation = pomdp.getObservation(entry.first);
if (observation) {
if (observation.get() != entryObservation) {
STORM_LOG_ERROR("Beliefsupport contains different observations.");
@ -176,6 +193,19 @@ namespace storm {
return getOrAddGridPointId(belief);
}
ValueType getBeliefActionReward(BeliefType const& belief, uint64_t const& localActionIndex) const {
STORM_LOG_ASSERT(!pomdpActionRewardVector.empty(), "Requested a reward although no reward model was specified.");
auto result = storm::utility::zero<ValueType>();
auto const& choiceIndices = pomdp.getTransitionMatrix().getRowGroupIndices();
for (auto const &entry : belief) {
uint64_t choiceIndex = choiceIndices[entry.first] + localActionIndex;
STORM_LOG_ASSERT(choiceIndex < choiceIndices[entry.first + 1], "Invalid local action index.");
STORM_LOG_ASSERT(choiceIndex < pomdpActionRewardVector.size(), "Invalid choice index.");
result += entry.second * pomdpActionRewardVector[choiceIndex];
}
return result;
}
uint32_t getBeliefObservation(BeliefType belief) {
STORM_LOG_ASSERT(assertBelief(belief), "Invalid belief.");
return pomdp.getObservation(belief.begin()->first);
@ -343,6 +373,7 @@ namespace storm {
}
PomdpType const& pomdp;
std::vector<ValueType> pomdpActionRewardVector;
std::vector<BeliefType> gridPoints;
std::map<BeliefType, BeliefId> gridPointToIdMap;

Loading…
Cancel
Save