Browse Source

BeliefManager: Made several methods private to hide the actual BeliefType.

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
26864067cf
  1. 164
      src/storm-pomdp/storage/BeliefManager.h

164
src/storm-pomdp/storage/BeliefManager.h

@ -44,6 +44,92 @@ namespace storm {
} }
}; };
BeliefId noId() const {
return std::numeric_limits<BeliefId>::max();
}
bool isEqual(BeliefId const& first, BeliefId const& second) const {
return isEqual(getBelief(first), getBelief(second));
}
std::string toString(BeliefId const& beliefId) const {
return toString(getBelief(beliefId));
}
std::string toString(Triangulation const& t) const {
std::stringstream str;
str << "(\n";
for (uint64_t i = 0; i < t.size(); ++i) {
str << "\t" << t.weights[i] << " * \t" << toString(getBelief(t.gridPoints[i])) << "\n";
}
str <<")\n";
return str.str();
}
template <typename SummandsType>
ValueType getWeightedSum(BeliefId const& beliefId, SummandsType const& summands) {
ValueType result = storm::utility::zero<ValueType>();
for (auto const& entry : getBelief(beliefId)) {
result += storm::utility::convertNumber<ValueType>(entry.second) * storm::utility::convertNumber<ValueType>(summands.at(entry.first));
}
return result;
}
BeliefId const& getInitialBelief() const {
return initialBeliefId;
}
ValueType getBeliefActionReward(BeliefId const& beliefId, uint64_t const& localActionIndex) const {
auto const& belief = getBelief(beliefId);
STORM_LOG_ASSERT(!pomdpActionRewardVector.empty(), "Requested a reward although no reward model was specified.");
auto result = storm::utility::zero<ValueType>();
auto const& choiceIndices = pomdp.getTransitionMatrix().getRowGroupIndices();
for (auto const &entry : belief) {
uint64_t choiceIndex = choiceIndices[entry.first] + localActionIndex;
STORM_LOG_ASSERT(choiceIndex < choiceIndices[entry.first + 1], "Invalid local action index.");
STORM_LOG_ASSERT(choiceIndex < pomdpActionRewardVector.size(), "Invalid choice index.");
result += entry.second * pomdpActionRewardVector[choiceIndex];
}
return result;
}
uint32_t getBeliefObservation(BeliefId beliefId) {
return getBeliefObservation(getBelief(beliefId));
}
uint64_t getBeliefNumberOfChoices(BeliefId beliefId) {
auto belief = getBelief(beliefId);
return pomdp.getNumberOfChoices(belief.begin()->first);
}
Triangulation triangulateBelief(BeliefId beliefId, uint64_t resolution) {
return triangulateBelief(getBelief(beliefId), resolution);
}
template<typename DistributionType>
void addToDistribution(DistributionType& distr, StateType const& state, BeliefValueType const& value) {
auto insertionRes = distr.emplace(state, value);
if (!insertionRes.second) {
insertionRes.first->second += value;
}
}
BeliefId getNumberOfBeliefIds() const {
return beliefs.size();
}
std::map<BeliefId, ValueType> expandAndTriangulate(BeliefId const& beliefId, uint64_t actionIndex, std::vector<uint64_t> const& observationResolutions) {
return expandInternal(beliefId, actionIndex, observationResolutions);
}
std::map<BeliefId, ValueType> expand(BeliefId const& beliefId, uint64_t actionIndex) {
return expandInternal(beliefId, actionIndex);
}
private:
BeliefType const& getBelief(BeliefId const& id) const { BeliefType const& getBelief(BeliefId const& id) const {
STORM_LOG_ASSERT(id != noId(), "Tried to get a non-existend belief."); STORM_LOG_ASSERT(id != noId(), "Tried to get a non-existend belief.");
STORM_LOG_ASSERT(id < getNumberOfBeliefIds(), "Belief index " << id << " is out of range."); STORM_LOG_ASSERT(id < getNumberOfBeliefIds(), "Belief index " << id << " is out of range.");
@ -56,10 +142,6 @@ namespace storm {
return idIt->second; return idIt->second;
} }
BeliefId noId() const {
return std::numeric_limits<BeliefId>::max();
}
std::string toString(BeliefType const& belief) const { std::string toString(BeliefType const& belief) const {
std::stringstream str; std::stringstream str;
str << "{ "; str << "{ ";
@ -76,16 +158,6 @@ namespace storm {
return str.str(); return str.str();
} }
std::string toString(Triangulation const& t) const {
std::stringstream str;
str << "(\n";
for (uint64_t i = 0; i < t.size(); ++i) {
str << "\t" << t.weights[i] << " * \t" << toString(getBelief(t.gridPoints[i])) << "\n";
}
str <<")\n";
return str.str();
}
bool isEqual(BeliefType const& first, BeliefType const& second) const { bool isEqual(BeliefType const& first, BeliefType const& second) const {
if (first.size() != second.size()) { if (first.size() != second.size()) {
return false; return false;
@ -186,49 +258,11 @@ namespace storm {
return true; return true;
} }
template <typename SummandsType>
ValueType getWeightedSum(BeliefId const& beliefId, SummandsType const& summands) {
ValueType result = storm::utility::zero<ValueType>();
for (auto const& entry : getBelief(beliefId)) {
result += storm::utility::convertNumber<ValueType>(entry.second) * storm::utility::convertNumber<ValueType>(summands.at(entry.first));
}
return result;
}
BeliefId const& getInitialBelief() const {
return initialBeliefId;
}
ValueType getBeliefActionReward(BeliefId const& beliefId, uint64_t const& localActionIndex) const {
auto const& belief = getBelief(beliefId);
STORM_LOG_ASSERT(!pomdpActionRewardVector.empty(), "Requested a reward although no reward model was specified.");
auto result = storm::utility::zero<ValueType>();
auto const& choiceIndices = pomdp.getTransitionMatrix().getRowGroupIndices();
for (auto const &entry : belief) {
uint64_t choiceIndex = choiceIndices[entry.first] + localActionIndex;
STORM_LOG_ASSERT(choiceIndex < choiceIndices[entry.first + 1], "Invalid local action index.");
STORM_LOG_ASSERT(choiceIndex < pomdpActionRewardVector.size(), "Invalid choice index.");
result += entry.second * pomdpActionRewardVector[choiceIndex];
}
return result;
}
uint32_t getBeliefObservation(BeliefType belief) { uint32_t getBeliefObservation(BeliefType belief) {
STORM_LOG_ASSERT(assertBelief(belief), "Invalid belief."); STORM_LOG_ASSERT(assertBelief(belief), "Invalid belief.");
return pomdp.getObservation(belief.begin()->first); return pomdp.getObservation(belief.begin()->first);
} }
uint32_t getBeliefObservation(BeliefId beliefId) {
return getBeliefObservation(getBelief(beliefId));
}
uint64_t getBeliefNumberOfChoices(BeliefId beliefId) {
auto belief = getBelief(beliefId);
return pomdp.getNumberOfChoices(belief.begin()->first);
}
Triangulation triangulateBelief(BeliefType belief, uint64_t resolution) { Triangulation triangulateBelief(BeliefType belief, uint64_t resolution) {
//TODO this can also be simplified using the sparse vector interpretation //TODO this can also be simplified using the sparse vector interpretation
//TODO Enable chaching for this method? //TODO Enable chaching for this method?
@ -307,22 +341,6 @@ namespace storm {
return result; return result;
} }
Triangulation triangulateBelief(BeliefId beliefId, uint64_t resolution) {
return triangulateBelief(getBelief(beliefId), resolution);
}
template<typename DistributionType>
void addToDistribution(DistributionType& distr, StateType const& state, BeliefValueType const& value) {
auto insertionRes = distr.emplace(state, value);
if (!insertionRes.second) {
insertionRes.first->second += value;
}
}
BeliefId getNumberOfBeliefIds() const {
return beliefs.size();
}
std::map<BeliefId, ValueType> expandInternal(BeliefId const& beliefId, uint64_t actionIndex, boost::optional<std::vector<uint64_t>> const& observationTriangulationResolutions = boost::none) { std::map<BeliefId, ValueType> expandInternal(BeliefId const& beliefId, uint64_t actionIndex, boost::optional<std::vector<uint64_t>> const& observationTriangulationResolutions = boost::none) {
std::map<BeliefId, ValueType> destinations; std::map<BeliefId, ValueType> destinations;
// TODO: Output as vector? // TODO: Output as vector?
@ -369,16 +387,6 @@ namespace storm {
} }
std::map<BeliefId, ValueType> expandAndTriangulate(BeliefId const& beliefId, uint64_t actionIndex, std::vector<uint64_t> const& observationResolutions) {
return expandInternal(beliefId, actionIndex, observationResolutions);
}
std::map<BeliefId, ValueType> expand(BeliefId const& beliefId, uint64_t actionIndex) {
return expandInternal(beliefId, actionIndex);
}
private:
BeliefId computeInitialBelief() { BeliefId computeInitialBelief() {
STORM_LOG_ASSERT(pomdp.getInitialStates().getNumberOfSetBits() < 2, STORM_LOG_ASSERT(pomdp.getInitialStates().getNumberOfSetBits() < 2,
"POMDP contains more than one initial state"); "POMDP contains more than one initial state");

Loading…
Cancel
Save