|
|
@ -24,6 +24,19 @@ namespace storm { |
|
|
|
// Intentionally left empty |
|
|
|
} |
|
|
|
|
|
|
|
/*!
 * Selects the reward model whose action rewards are used by this object.
 *
 * The reward model is looked up in the underlying POMDP and its total reward vector
 * (computed with respect to the POMDP's transition matrix) is cached in
 * pomdpActionRewardVector.
 *
 * @param rewardModelName Name of the reward model to use. If no name is given, the name
 *                        returned by pomdp.getUniqueRewardModelName() is used instead
 *                        (presumably this requires the POMDP to have exactly one reward
 *                        model -- confirm against the model class).
 */
void setRewardModel(boost::optional<std::string> rewardModelName = boost::none) {
    // Resolve the default first so that a single code path performs the actual lookup.
    if (!rewardModelName) {
        rewardModelName = pomdp.getUniqueRewardModelName();
    }
    auto const& selectedModel = pomdp.getRewardModel(rewardModelName.get());
    pomdpActionRewardVector = selectedModel.getTotalRewardVector(pomdp.getTransitionMatrix());
}
|
|
|
|
|
|
|
/*!
 * Discards the currently cached action rewards.
 *
 * Afterwards pomdpActionRewardVector is empty, which getBeliefActionReward() treats as
 * "no reward model was specified" (it asserts that the vector is non-empty).
 */
void unsetRewardModel() {
    pomdpActionRewardVector.clear();
}
|
|
|
|
|
|
|
struct Triangulation { |
|
|
|
std::vector<BeliefId> gridPoints; |
|
|
|
std::vector<BeliefValueType> weights; |
|
|
@ -89,7 +102,11 @@ namespace storm { |
|
|
|
BeliefValueType sum = storm::utility::zero<ValueType>(); |
|
|
|
boost::optional<uint32_t> observation; |
|
|
|
for (auto const& entry : belief) { |
|
|
|
uintmax_t entryObservation = pomdp.getObservation(entry.first); |
|
|
|
if (entry.first >= pomdp.getNumberOfStates()) { |
|
|
|
STORM_LOG_ERROR("Belief does refer to non-existing pomdp state " << entry.first << "."); |
|
|
|
return false; |
|
|
|
} |
|
|
|
uint64_t entryObservation = pomdp.getObservation(entry.first); |
|
|
|
if (observation) { |
|
|
|
if (observation.get() != entryObservation) { |
|
|
|
STORM_LOG_ERROR("Beliefsupport contains different observations."); |
|
|
@ -176,6 +193,19 @@ namespace storm { |
|
|
|
return getOrAddGridPointId(belief); |
|
|
|
} |
|
|
|
|
|
|
|
/*!
 * Computes the reward of taking an action in a belief as the weighted sum of the cached
 * action rewards over all entries of the belief.
 *
 * For every (state, weight) entry of the belief, the choice is located via the row group
 * indices of the POMDP's transition matrix: row group start of the state plus the local
 * action index.
 *
 * @param belief The belief whose entries (state -> weight) are summed over.
 * @param localActionIndex Offset of the action within the row group of each state.
 * @return The sum of weight * reward over all entries of the belief.
 *
 * @note A reward model must have been set beforehand; this is only checked by assertion.
 */
ValueType getBeliefActionReward(BeliefType const& belief, uint64_t const& localActionIndex) const {
    STORM_LOG_ASSERT(!pomdpActionRewardVector.empty(), "Requested a reward although no reward model was specified.");
    auto const& rowGroupIndices = pomdp.getTransitionMatrix().getRowGroupIndices();
    ValueType accumulatedReward = storm::utility::zero<ValueType>();
    for (auto const& stateAndWeight : belief) {
        auto const& state = stateAndWeight.first;
        // Translate the action index local to this state into a global choice (row) index.
        uint64_t globalChoice = rowGroupIndices[state] + localActionIndex;
        STORM_LOG_ASSERT(globalChoice < rowGroupIndices[state + 1], "Invalid local action index.");
        STORM_LOG_ASSERT(globalChoice < pomdpActionRewardVector.size(), "Invalid choice index.");
        accumulatedReward += stateAndWeight.second * pomdpActionRewardVector[globalChoice];
    }
    return accumulatedReward;
}
|
|
|
|
|
|
|
uint32_t getBeliefObservation(BeliefType belief) { |
|
|
|
STORM_LOG_ASSERT(assertBelief(belief), "Invalid belief."); |
|
|
|
return pomdp.getObservation(belief.begin()->first); |
|
|
@ -343,6 +373,7 @@ namespace storm { |
|
|
|
} |
|
|
|
|
|
|
|
// The POMDP this object works on. Stored as a const reference, so it is not owned here.
PomdpType const& pomdp;



// Cached rewards, indexed by global choice index (row group start of a state plus local action index).
// Filled by setRewardModel() and emptied by unsetRewardModel(); empty means no reward model is set.
std::vector<ValueType> pomdpActionRewardVector;







// NOTE(review): presumably the beliefs registered as grid points so far, indexed by BeliefId -- confirm.
std::vector<BeliefType> gridPoints;



// NOTE(review): presumably the reverse lookup from a belief to its id in gridPoints -- confirm.
std::map<BeliefType, BeliefId> gridPointToIdMap;
|
|
|