From 2f020ce6860af6cb45a460aca32e1d52f910df74 Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Wed, 8 Apr 2020 08:12:32 +0200 Subject: [PATCH] BeliefManager: Making Freudenthal happy (and fast) --- src/storm-pomdp/storage/BeliefManager.h | 122 ++++++++++-------------- 1 file changed, 53 insertions(+), 69 deletions(-) diff --git a/src/storm-pomdp/storage/BeliefManager.h b/src/storm-pomdp/storage/BeliefManager.h index 7e37b9c16..95c596005 100644 --- a/src/storm-pomdp/storage/BeliefManager.h +++ b/src/storm-pomdp/storage/BeliefManager.h @@ -263,96 +263,80 @@ namespace storm { return pomdp.getObservation(belief.begin()->first); } - struct FreudenthalData { - FreudenthalData(StateType const& pomdpState, StateType const& dimension, BeliefValueType const& x) : pomdpState(pomdpState), dimension(dimension), value(storm::utility::floor(x)), diff(x-value) { }; - StateType pomdpState; + struct FreudenthalDiff { + FreudenthalDiff(StateType const& dimension, BeliefValueType&& diff) : dimension(dimension), diff(std::move(diff)) { }; StateType dimension; // i - BeliefValueType value; // v[i] in the Lovejoy paper - BeliefValueType diff; // d[i] in the Lovejoy paper - }; - struct FreudenthalDataComparator { - bool operator()(FreudenthalData const& first, FreudenthalData const& second) const { - if (first.diff != second.diff) { - return first.diff > second.diff; + BeliefValueType diff; // d[i] + bool operator>(FreudenthalDiff const& other) const { + if (diff != other.diff) { + return diff > other.diff; } else { - return first.dimension < second.dimension; + return dimension < other.dimension; } } }; Triangulation triangulateBelief(BeliefType belief, uint64_t resolution) { - //TODO Enable chaching for this method? 
STORM_LOG_ASSERT(assertBelief(belief), "Input belief for triangulation is not valid."); - - auto convResolution = storm::utility::convertNumber(resolution); - - // This is the Freudenthal Triangulation as described in Lovejoy (a whole lotta math) - // Variable names are based on the paper - // However, we speed this up a little by exploiting that belief states usually have sparse support. - // TODO: for the sorting, it probably suffices to have a map from diffs to dimensions. The other Freudenthaldata could then also be stored in vectors, which would be a bit more like the original algorithm - - // Initialize some data - std::vector::iterator> dataIterators; - dataIterators.reserve(belief.size()); - // Initialize first row of 'qs' matrix - std::vector qsRow; - qsRow.reserve(dataIterators.size()); - std::set freudenthalData; - BeliefValueType x = convResolution; - for (auto const& entry : belief) { - auto insertionIt = freudenthalData.emplace(entry.first, dataIterators.size(), x).first; - dataIterators.push_back(insertionIt); - qsRow.push_back(dataIterators.back()->value); - x -= entry.second * convResolution; - } - qsRow.push_back(storm::utility::zero()); - assert(!freudenthalData.empty()); - + StateType numEntries = belief.size(); Triangulation result; - result.weights.reserve(freudenthalData.size()); - result.gridPoints.reserve(freudenthalData.size()); - - // Insert first grid point - // TODO: this special treatment is actually not necessary. 
- BeliefValueType firstWeight = storm::utility::one() - freudenthalData.begin()->diff + freudenthalData.rbegin()->diff; - if (!cc.isZero(firstWeight)) { - result.weights.push_back(firstWeight); - BeliefType gridPoint; - for (StateType j = 0; j < dataIterators.size(); ++j) { - BeliefValueType gridPointEntry = qsRow[j] - qsRow[j + 1]; - if (!cc.isZero(gridPointEntry)) { - gridPoint[dataIterators[j]->pomdpState] = gridPointEntry / convResolution; - } + + // Quickly triangulate Dirac beliefs + if (numEntries == 1u) { + result.weights.push_back(storm::utility::one()); + result.gridPoints.push_back(getOrAddBeliefId(belief)); + } else { + + auto convResolution = storm::utility::convertNumber(resolution); + // This is the Freudenthal Triangulation as described in Lovejoy (a whole lotta math) + // Variable names are mostly based on the paper + // However, we speed this up a little by exploiting that belief states usually have sparse support (i.e. numEntries is much smaller than pomdp.getNumberOfStates()). + // Initialize diffs and the first row of the 'qs' matrix (aka v) + std::set> sorted_diffs; // d (and p?) 
in the paper + std::vector qsRow; // Row of the 'qs' matrix from the paper (initially corresponds to v) + qsRow.reserve(numEntries); + std::vector toOriginalIndicesMap; // Maps 'local' indices to the original pomdp state indices + toOriginalIndicesMap.reserve(numEntries); + BeliefValueType x = convResolution; + for (auto const& entry : belief) { + qsRow.push_back(storm::utility::floor(x)); // v + sorted_diffs.emplace(toOriginalIndicesMap.size(), x - qsRow.back()); // x-v + toOriginalIndicesMap.push_back(entry.first); + x -= entry.second * convResolution; } - result.gridPoints.push_back(getOrAddBeliefId(gridPoint)); - } - - if (freudenthalData.size() > 1) { - // Insert remaining grid points - auto currentSortedEntry = freudenthalData.begin(); - auto previousSortedEntry = currentSortedEntry++; - for (StateType i = 1; i < dataIterators.size(); ++i) { - // 'compute' the next row of the qs matrix - qsRow[previousSortedEntry->dimension] += storm::utility::one(); - - BeliefValueType weight = previousSortedEntry->diff - currentSortedEntry->diff; + // Insert a dummy 0 column in the qs matrix so the loops below are a bit simpler + qsRow.push_back(storm::utility::zero()); + + result.weights.reserve(numEntries); + result.gridPoints.reserve(numEntries); + auto currentSortedDiff = sorted_diffs.begin(); + auto previousSortedDiff = sorted_diffs.end(); + --previousSortedDiff; + for (StateType i = 0; i < numEntries; ++i) { + // Compute the weight for the grid points + BeliefValueType weight = previousSortedDiff->diff - currentSortedDiff->diff; + if (i == 0) { + // The first weight is a bit different + weight += storm::utility::one(); + } else { + // 'compute' the next row of the qs matrix + qsRow[previousSortedDiff->dimension] += storm::utility::one(); + } if (!cc.isZero(weight)) { result.weights.push_back(weight); - + // Compute the grid point BeliefType gridPoint; - for (StateType j = 0; j < dataIterators.size(); ++j) { + for (StateType j = 0; j < numEntries; ++j) { BeliefValueType 
gridPointEntry = qsRow[j] - qsRow[j + 1]; if (!cc.isZero(gridPointEntry)) { - gridPoint[dataIterators[j]->pomdpState] = gridPointEntry / convResolution; + gridPoint[toOriginalIndicesMap[j]] = gridPointEntry / convResolution; } } result.gridPoints.push_back(getOrAddBeliefId(gridPoint)); } - ++previousSortedEntry; - ++currentSortedEntry; + previousSortedDiff = currentSortedDiff++; } } - STORM_LOG_ASSERT(assertTriangulation(belief, result), "Incorrect triangulation: " << toString(result)); return result; }