diff --git a/src/storm-pomdp/storage/BeliefGrid.h b/src/storm-pomdp/storage/BeliefGrid.h index b0661fcc5..61fe3b370 100644 --- a/src/storm-pomdp/storage/BeliefGrid.h +++ b/src/storm-pomdp/storage/BeliefGrid.h @@ -24,6 +24,19 @@ namespace storm { // Intentionally left empty } + void setRewardModel(boost::optional rewardModelName = boost::none) { + if (rewardModelName) { + auto const& rewardModel = pomdp.getRewardModel(rewardModelName.get()); + pomdpActionRewardVector = rewardModel.getTotalRewardVector(pomdp.getTransitionMatrix()); + } else { + setRewardModel(pomdp.getUniqueRewardModelName()); + } + } + + void unsetRewardModel() { + pomdpActionRewardVector.clear(); + } + struct Triangulation { std::vector gridPoints; std::vector weights; @@ -89,7 +102,11 @@ namespace storm { BeliefValueType sum = storm::utility::zero(); boost::optional observation; for (auto const& entry : belief) { - uintmax_t entryObservation = pomdp.getObservation(entry.first); + if (entry.first >= pomdp.getNumberOfStates()) { + STORM_LOG_ERROR("Belief does refer to non-existing pomdp state " << entry.first << "."); + return false; + } + uint64_t entryObservation = pomdp.getObservation(entry.first); if (observation) { if (observation.get() != entryObservation) { STORM_LOG_ERROR("Beliefsupport contains different observations."); @@ -176,6 +193,19 @@ namespace storm { return getOrAddGridPointId(belief); } + ValueType getBeliefActionReward(BeliefType const& belief, uint64_t const& localActionIndex) const { + STORM_LOG_ASSERT(!pomdpActionRewardVector.empty(), "Requested a reward although no reward model was specified."); + auto result = storm::utility::zero(); + auto const& choiceIndices = pomdp.getTransitionMatrix().getRowGroupIndices(); + for (auto const &entry : belief) { + uint64_t choiceIndex = choiceIndices[entry.first] + localActionIndex; + STORM_LOG_ASSERT(choiceIndex < choiceIndices[entry.first + 1], "Invalid local action index."); + STORM_LOG_ASSERT(choiceIndex < pomdpActionRewardVector.size(), "Invalid choice index."); + result += entry.second * pomdpActionRewardVector[choiceIndex]; + } + return result; + } + uint32_t getBeliefObservation(BeliefType belief) { STORM_LOG_ASSERT(assertBelief(belief), "Invalid belief."); return pomdp.getObservation(belief.begin()->first); @@ -343,6 +373,7 @@ namespace storm { } PomdpType const& pomdp; + std::vector pomdpActionRewardVector; std::vector gridPoints; std::map gridPointToIdMap;