354 lines
17 KiB

#pragma once
#include <map>
#include <boost/optional.hpp>
//#include <boost/container/flat_map.hpp>
#include "storm/utility/macros.h"
#include "storm/exceptions/UnexpectedException.h"
namespace storm {
namespace storage {
template <typename PomdpType, typename BeliefValueType = typename PomdpType::ValueType, typename StateType = uint64_t>
// TODO: Change name. This actually does not only manage grid points.
class BeliefGrid {
public:
typedef typename PomdpType::ValueType ValueType;
//typedef boost::container::flat_map<StateType, BeliefValueType> BeliefType
typedef std::map<StateType, BeliefValueType> BeliefType;
typedef uint64_t BeliefId;
BeliefGrid(PomdpType const& pomdp, BeliefValueType const& precision) : pomdp(pomdp), cc(precision, false) {
// Intentionally left empty
}
struct Triangulation {
std::vector<BeliefId> gridPoints;
std::vector<BeliefValueType> weights;
uint64_t size() const {
return weights.size();
}
};
BeliefType const& getGridPoint(BeliefId const& id) const {
return gridPoints[id];
}
BeliefId getIdOfGridPoint(BeliefType const& gridPoint) const {
auto idIt = gridPointToIdMap.find(gridPoint);
STORM_LOG_THROW(idIt != gridPointToIdMap.end(), storm::exceptions::UnexpectedException, "Unknown grid state.");
return idIt->second;
}
std::string toString(BeliefType const& belief) const {
std::stringstream str;
str << "{ ";
bool first = true;
for (auto const& entry : belief) {
if (first) {
first = false;
} else {
str << ", ";
}
str << entry.first << ": " << entry.second;
}
str << " }";
return str.str();
}
std::string toString(Triangulation const& t) const {
std::stringstream str;
str << "(\n";
for (uint64_t i = 0; i < t.size(); ++i) {
str << "\t" << t.weights[i] << " * \t" << toString(getGridPoint(t.gridPoints[i])) << "\n";
}
str <<")\n";
return str.str();
}
bool isEqual(BeliefType const& first, BeliefType const& second) const {
if (first.size() != second.size()) {
return false;
}
auto secondIt = second.begin();
for (auto const& firstEntry : first) {
if (firstEntry.first != secondIt->first) {
return false;
}
if (!cc.isEqual(firstEntry.second, secondIt->second)) {
return false;
}
++secondIt;
}
return true;
}
bool assertBelief(BeliefType const& belief) const {
BeliefValueType sum = storm::utility::zero<ValueType>();
boost::optional<uint32_t> observation;
for (auto const& entry : belief) {
uintmax_t entryObservation = pomdp.getObservation(entry.first);
if (observation) {
if (observation.get() != entryObservation) {
STORM_LOG_ERROR("Beliefsupport contains different observations.");
return false;
}
} else {
observation = entryObservation;
}
if (cc.isZero(entry.second)) {
// We assume that beliefs only consider their support.
STORM_LOG_ERROR("Zero belief probability.");
return false;
}
if (cc.isLess(entry.second, storm::utility::zero<BeliefValueType>())) {
STORM_LOG_ERROR("Negative belief probability.");
return false;
}
if (cc.isLess(storm::utility::one<BeliefValueType>(), entry.second)) {
STORM_LOG_ERROR("Belief probability greater than one.");
return false;
}
sum += entry.second;
}
if (!cc.isOne(sum)) {
STORM_LOG_ERROR("Belief does not sum up to one.");
return false;
}
return true;
}
bool assertTriangulation(BeliefType const& belief, Triangulation const& triangulation) const {
if (triangulation.weights.size() != triangulation.gridPoints.size()) {
STORM_LOG_ERROR("Number of weights and points in triangulation does not match.");
return false;
}
if (triangulation.size() == 0) {
STORM_LOG_ERROR("Empty triangulation.");
return false;
}
BeliefType triangulatedBelief;
BeliefValueType weightSum = storm::utility::zero<BeliefValueType>();
for (uint64_t i = 0; i < triangulation.weights.size(); ++i) {
if (cc.isZero(triangulation.weights[i])) {
STORM_LOG_ERROR("Zero weight in triangulation.");
return false;
}
if (cc.isLess(triangulation.weights[i], storm::utility::zero<BeliefValueType>())) {
STORM_LOG_ERROR("Negative weight in triangulation.");
return false;
}
if (cc.isLess(storm::utility::one<BeliefValueType>(), triangulation.weights[i])) {
STORM_LOG_ERROR("Weight greater than one in triangulation.");
}
weightSum += triangulation.weights[i];
BeliefType const& gridPoint = getGridPoint(triangulation.gridPoints[i]);
for (auto const& pointEntry : gridPoint) {
BeliefValueType& triangulatedValue = triangulatedBelief.emplace(pointEntry.first, storm::utility::zero<ValueType>()).first->second;
triangulatedValue += triangulation.weights[i] * pointEntry.second;
}
}
if (!cc.isOne(weightSum)) {
STORM_LOG_ERROR("Triangulation weights do not sum up to one.");
return false;
}
if (!assertBelief(triangulatedBelief)) {
STORM_LOG_ERROR("Triangulated belief is not a belief.");
}
if (!isEqual(belief, triangulatedBelief)) {
STORM_LOG_ERROR("Belief:\n\t" << toString(belief) << "\ndoes not match triangulated belief:\n\t" << toString(triangulatedBelief) << ".");
return false;
}
return true;
}
BeliefId getInitialBelief() {
STORM_LOG_ASSERT(pomdp.getInitialStates().getNumberOfSetBits() < 2,
"POMDP contains more than one initial state");
STORM_LOG_ASSERT(pomdp.getInitialStates().getNumberOfSetBits() == 1,
"POMDP does not contain an initial state");
BeliefType belief;
belief[*pomdp.getInitialStates().begin()] = storm::utility::one<ValueType>();
STORM_LOG_ASSERT(assertBelief(belief), "Invalid initial belief.");
return getOrAddGridPointId(belief);
}
uint32_t getBeliefObservation(BeliefType belief) {
STORM_LOG_ASSERT(assertBelief(belief), "Invalid belief.");
return pomdp.getObservation(belief.begin()->first);
}
uint32_t getBeliefObservation(BeliefId beliefId) {
return getBeliefObservation(getGridPoint(beliefId));
}
Triangulation triangulateBelief(BeliefType belief, uint64_t resolution) {
//TODO this can also be simplified using the sparse vector interpretation
//TODO Enable chaching for this method?
STORM_LOG_ASSERT(assertBelief(belief), "Input belief for triangulation is not valid.");
auto nrStates = pomdp.getNumberOfStates();
// This is the Freudenthal Triangulation as described in Lovejoy (a whole lotta math)
// Variable names are based on the paper
// TODO avoid reallocations for these vectors
std::vector<BeliefValueType> x(nrStates);
std::vector<BeliefValueType> v(nrStates);
std::vector<BeliefValueType> d(nrStates);
auto convResolution = storm::utility::convertNumber<BeliefValueType>(resolution);
for (size_t i = 0; i < nrStates; ++i) {
for (auto const &probEntry : belief) {
if (probEntry.first >= i) {
x[i] += convResolution * probEntry.second;
}
}
v[i] = storm::utility::floor(x[i]);
d[i] = x[i] - v[i];
}
auto p = storm::utility::vector::getSortedIndices(d);
std::vector<std::vector<BeliefValueType>> qs(nrStates, std::vector<BeliefValueType>(nrStates));
for (size_t i = 0; i < nrStates; ++i) {
if (i == 0) {
for (size_t j = 0; j < nrStates; ++j) {
qs[i][j] = v[j];
}
} else {
for (size_t j = 0; j < nrStates; ++j) {
if (j == p[i - 1]) {
qs[i][j] = qs[i - 1][j] + storm::utility::one<BeliefValueType>();
} else {
qs[i][j] = qs[i - 1][j];
}
}
}
}
Triangulation result;
// The first weight is 1-sum(other weights). We therefore process the js in reverse order
BeliefValueType firstWeight = storm::utility::one<BeliefValueType>();
for (size_t j = nrStates; j > 0;) {
--j;
// First create the weights. The weights vector will be reversed at the end.
ValueType weight;
if (j == 0) {
weight = firstWeight;
} else {
weight = d[p[j - 1]] - d[p[j]];
firstWeight -= weight;
}
if (!cc.isZero(weight)) {
result.weights.push_back(weight);
BeliefType gridPoint;
auto const& qsj = qs[j];
for (size_t i = 0; i < nrStates - 1; ++i) {
BeliefValueType gridPointEntry = qsj[i] - qsj[i + 1];
if (!cc.isZero(gridPointEntry)) {
gridPoint[i] = gridPointEntry / convResolution;
}
}
if (!cc.isZero(qsj[nrStates - 1])) {
gridPoint[nrStates - 1] = qsj[nrStates - 1] / convResolution;
}
result.gridPoints.push_back(getOrAddGridPointId(gridPoint));
}
}
STORM_LOG_ASSERT(assertTriangulation(belief, result), "Incorrect triangulation: " << toString(result));
return result;
}
template<typename DistributionType>
void addToDistribution(DistributionType& distr, StateType const& state, BeliefValueType const& value) {
auto insertionRes = distr.emplace(state, value);
if (!insertionRes.second) {
insertionRes.first->second += value;
}
}
BeliefId getNumberOfGridPointIds() const {
return gridPoints.size();
}
std::map<BeliefId, ValueType> expandInternal(BeliefId const& gridPointId, uint64_t actionIndex, boost::optional<std::vector<uint64_t>> const& observationTriangulationResolutions = boost::none) {
std::map<BeliefId, ValueType> destinations; // The belief ids should be ordered
// TODO: Does this make sense? It could be better to order them afterwards because now we rely on the fact that MDP states have the same order than their associated BeliefIds
BeliefType gridPoint = getGridPoint(gridPointId);
// Find the probability we go to each observation
BeliefType successorObs; // This is actually not a belief but has the same type
for (auto const& pointEntry : gridPoint) {
uint64_t state = pointEntry.first;
for (auto const& pomdpTransition : pomdp.getTransitionMatrix().getRow(state, actionIndex)) {
if (!storm::utility::isZero(pomdpTransition.getValue())) {
auto obs = pomdp.getObservation(pomdpTransition.getColumn());
addToDistribution(successorObs, obs, pointEntry.second * pomdpTransition.getValue());
}
}
}
// Now for each successor observation we find and potentially triangulate the successor belief
for (auto const& successor : successorObs) {
BeliefType successorBelief;
for (auto const& pointEntry : gridPoint) {
uint64_t state = pointEntry.first;
for (auto const& pomdpTransition : pomdp.getTransitionMatrix().getRow(state, actionIndex)) {
if (pomdp.getObservation(pomdpTransition.getColumn()) == successor.first) {
ValueType prob = pointEntry.second * pomdpTransition.getValue() / successor.second;
addToDistribution(successorBelief, pomdpTransition.getColumn(), prob);
}
}
}
STORM_LOG_ASSERT(assertBelief(successorBelief), "Invalid successor belief.");
if (observationTriangulationResolutions) {
Triangulation triangulation = triangulateBelief(successorBelief, observationTriangulationResolutions.get()[successor.first]);
for (size_t j = 0; j < triangulation.size(); ++j) {
addToDistribution(destinations, triangulation.gridPoints[j], triangulation.weights[j] * successor.second);
}
} else {
addToDistribution(destinations, getOrAddGridPointId(successorBelief), successor.second);
}
}
return destinations;
}
std::map<BeliefId, ValueType> expandAndTriangulate(BeliefId const& gridPointId, uint64_t actionIndex, std::vector<uint64_t> const& observationResolutions) {
return expandInternal(gridPointId, actionIndex, observationResolutions);
}
std::map<BeliefId, ValueType> expand(BeliefId const& gridPointId, uint64_t actionIndex) {
return expandInternal(gridPointId, actionIndex);
}
private:
BeliefId getOrAddGridPointId(BeliefType const& gridPoint) {
auto insertioRes = gridPointToIdMap.emplace(gridPoint, gridPoints.size());
if (insertioRes.second) {
// There actually was an insertion, so add the new grid state
gridPoints.push_back(gridPoint);
}
// Return the id
return insertioRes.first->second;
}
PomdpType const& pomdp;
std::vector<BeliefType> gridPoints;
std::map<BeliefType, BeliefId> gridPointToIdMap;
storm::utility::ConstantsComparator<ValueType> cc;
};
}
}