Browse Source

BeliefManager: Making Freudenthal happy (and fast)

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
2f020ce686
  1. 122
      src/storm-pomdp/storage/BeliefManager.h

122
src/storm-pomdp/storage/BeliefManager.h

@ -263,96 +263,80 @@ namespace storm {
return pomdp.getObservation(belief.begin()->first);
}
struct FreudenthalData {
FreudenthalData(StateType const& pomdpState, StateType const& dimension, BeliefValueType const& x) : pomdpState(pomdpState), dimension(dimension), value(storm::utility::floor(x)), diff(x-value) { };
StateType pomdpState;
struct FreudenthalDiff {
FreudenthalDiff(StateType const& dimension, BeliefValueType&& diff) : dimension(dimension), diff(std::move(diff)) { };
StateType dimension; // i
BeliefValueType value; // v[i] in the Lovejoy paper
BeliefValueType diff; // d[i] in the Lovejoy paper
};
struct FreudenthalDataComparator {
bool operator()(FreudenthalData const& first, FreudenthalData const& second) const {
if (first.diff != second.diff) {
return first.diff > second.diff;
BeliefValueType diff; // d[i]
bool operator>(FreudenthalDiff const& other) const {
if (diff != other.diff) {
return diff > other.diff;
} else {
return first.dimension < second.dimension;
return dimension < other.dimension;
}
}
};
Triangulation triangulateBelief(BeliefType belief, uint64_t resolution) {
//TODO Enable chaching for this method?
STORM_LOG_ASSERT(assertBelief(belief), "Input belief for triangulation is not valid.");
auto convResolution = storm::utility::convertNumber<BeliefValueType>(resolution);
// This is the Freudenthal Triangulation as described in Lovejoy (a whole lotta math)
// Variable names are based on the paper
// However, we speed this up a little by exploiting that belief states usually have sparse support.
// TODO: for the sorting, it probably suffices to have a map from diffs to dimensions. The other Freudenthaldata could then also be stored in vectors, which would be a bit more like the original algorithm
// Initialize some data
std::vector<typename std::set<FreudenthalData, FreudenthalDataComparator>::iterator> dataIterators;
dataIterators.reserve(belief.size());
// Initialize first row of 'qs' matrix
std::vector<BeliefValueType> qsRow;
qsRow.reserve(dataIterators.size());
std::set<FreudenthalData, FreudenthalDataComparator> freudenthalData;
BeliefValueType x = convResolution;
for (auto const& entry : belief) {
auto insertionIt = freudenthalData.emplace(entry.first, dataIterators.size(), x).first;
dataIterators.push_back(insertionIt);
qsRow.push_back(dataIterators.back()->value);
x -= entry.second * convResolution;
}
qsRow.push_back(storm::utility::zero<BeliefValueType>());
assert(!freudenthalData.empty());
StateType numEntries = belief.size();
Triangulation result;
result.weights.reserve(freudenthalData.size());
result.gridPoints.reserve(freudenthalData.size());
// Insert first grid point
// TODO: this special treatment is actually not necessary.
BeliefValueType firstWeight = storm::utility::one<ValueType>() - freudenthalData.begin()->diff + freudenthalData.rbegin()->diff;
if (!cc.isZero(firstWeight)) {
result.weights.push_back(firstWeight);
BeliefType gridPoint;
for (StateType j = 0; j < dataIterators.size(); ++j) {
BeliefValueType gridPointEntry = qsRow[j] - qsRow[j + 1];
if (!cc.isZero(gridPointEntry)) {
gridPoint[dataIterators[j]->pomdpState] = gridPointEntry / convResolution;
}
// Quickly triangulate Dirac beliefs
if (numEntries == 1u) {
result.weights.push_back(storm::utility::one<BeliefValueType>());
result.gridPoints.push_back(getOrAddBeliefId(belief));
} else {
auto convResolution = storm::utility::convertNumber<BeliefValueType>(resolution);
// This is the Freudenthal Triangulation as described in Lovejoy (a whole lotta math)
// Variable names are mostly based on the paper
// However, we speed this up a little by exploiting that belief states usually have sparse support (i.e. numEntries is much smaller than pomdp.getNumberOfStates()).
// Initialize diffs and the first row of the 'qs' matrix (aka v)
std::set<FreudenthalDiff, std::greater<FreudenthalDiff>> sorted_diffs; // d (and p?) in the paper
std::vector<BeliefValueType> qsRow; // Row of the 'qs' matrix from the paper (initially corresponds to v
qsRow.reserve(numEntries);
std::vector<StateType> toOriginalIndicesMap; // Maps 'local' indices to the original pomdp state indices
toOriginalIndicesMap.reserve(numEntries);
BeliefValueType x = convResolution;
for (auto const& entry : belief) {
qsRow.push_back(storm::utility::floor(x)); // v
sorted_diffs.emplace(toOriginalIndicesMap.size(), x - qsRow.back()); // x-v
toOriginalIndicesMap.push_back(entry.first);
x -= entry.second * convResolution;
}
result.gridPoints.push_back(getOrAddBeliefId(gridPoint));
}
if (freudenthalData.size() > 1) {
// Insert remaining grid points
auto currentSortedEntry = freudenthalData.begin();
auto previousSortedEntry = currentSortedEntry++;
for (StateType i = 1; i < dataIterators.size(); ++i) {
// 'compute' the next row of the qs matrix
qsRow[previousSortedEntry->dimension] += storm::utility::one<BeliefValueType>();
BeliefValueType weight = previousSortedEntry->diff - currentSortedEntry->diff;
// Insert a dummy 0 column in the qs matrix so the loops below are a bit simpler
qsRow.push_back(storm::utility::zero<BeliefValueType>());
result.weights.reserve(numEntries);
result.gridPoints.reserve(numEntries);
auto currentSortedDiff = sorted_diffs.begin();
auto previousSortedDiff = sorted_diffs.end();
--previousSortedDiff;
for (StateType i = 0; i < numEntries; ++i) {
// Compute the weight for the grid points
BeliefValueType weight = previousSortedDiff->diff - currentSortedDiff->diff;
if (i == 0) {
// The first weight is a bit different
weight += storm::utility::one<ValueType>();
} else {
// 'compute' the next row of the qs matrix
qsRow[previousSortedDiff->dimension] += storm::utility::one<BeliefValueType>();
}
if (!cc.isZero(weight)) {
result.weights.push_back(weight);
// Compute the grid point
BeliefType gridPoint;
for (StateType j = 0; j < dataIterators.size(); ++j) {
for (StateType j = 0; j < numEntries; ++j) {
BeliefValueType gridPointEntry = qsRow[j] - qsRow[j + 1];
if (!cc.isZero(gridPointEntry)) {
gridPoint[dataIterators[j]->pomdpState] = gridPointEntry / convResolution;
gridPoint[toOriginalIndicesMap[j]] = gridPointEntry / convResolution;
}
}
result.gridPoints.push_back(getOrAddBeliefId(gridPoint));
}
++previousSortedEntry;
++currentSortedEntry;
previousSortedDiff = currentSortedDiff++;
}
}
STORM_LOG_ASSERT(assertTriangulation(belief, result), "Incorrect triangulation: " << toString(result));
return result;
}

Loading…
Cancel
Save