From b901a1b308fb08a3d491bb4fae7da8170f99acf5 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Sat, 13 Jun 2020 18:09:22 -0700 Subject: [PATCH] fix for the binary pomdp transformer: labels are now attached to auxiliary states as well --- .../transformer/BinaryPomdpTransformer.cpp | 19 ++++++++++++++----- .../transformer/BinaryPomdpTransformer.h | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp b/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp index 6d85f8cad..385e415f1 100644 --- a/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp +++ b/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp @@ -29,10 +29,11 @@ namespace storm { } struct BinaryPomdpTransformerRowGroup { - BinaryPomdpTransformerRowGroup(uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) { + BinaryPomdpTransformerRowGroup(uint64_t origState, uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : origState(origState), firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) { // Intentionally left empty. } - + + uint64_t origState; uint64_t firstRow; uint64_t endRow; uint32_t origStateObservation; @@ -46,11 +47,11 @@ namespace storm { assert(size() > 1); uint64_t midRow = firstRow + size()/2; std::vector res; - res.emplace_back(firstRow, midRow, origStateObservation); + res.emplace_back(origState, firstRow, midRow, origStateObservation); storm::storage::BitVector newAuxStateId = auxStateId; newAuxStateId.resize(auxStateId.size() + 1, false); res.back().auxStateId = newAuxStateId; - res.emplace_back(midRow, endRow, origStateObservation); + res.emplace_back(origState, midRow, endRow, origStateObservation); newAuxStateId.set(auxStateId.size(), true); res.back().auxStateId = newAuxStateId; return res; @@ -75,7 +76,7 @@ namespace storm { // Initialize a FIFO Queue that stores the start and the end of each row group std::queue queue; for (uint64_t state = 0; state < matrix.getRowGroupCount(); ++state) { - queue.emplace(matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state)); + queue.emplace(state, matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state)); } std::vector newObservations; @@ -84,6 +85,8 @@ namespace storm { uint64_t currRow = 0; std::vector origRowToSimpleRowMap(pomdp.getNumberOfChoices(), std::numeric_limits::max()); uint64_t currAuxState = queue.size(); + std::vector origStates; + while (!queue.empty()) { auto group = std::move(queue.front()); queue.pop(); @@ -120,12 +123,14 @@ namespace storm { } } // Nothing to be done if group has size zero + origStates.push_back(group.origState); } TransformationData result; result.simpleMatrix = builder.build(currRow, currAuxState, currAuxState); result.simpleObservations = std::move(newObservations); result.originalToSimpleChoiceMap = std::move(origRowToSimpleRowMap); + result.simpleStateToOriginalState = std::move(origStates); return result; } @@ -136,7 +141,11 @@ namespace storm { for (auto const& labelName : pomdp.getStateLabeling().getLabels()) { storm::storage::BitVector newStates = pomdp.getStateLabeling().getStates(labelName); newStates.resize(data.simpleMatrix.getRowGroupCount(), false); + for (uint64_t newState = pomdp.getNumberOfStates(); newState < data.simpleMatrix.getRowGroupCount(); ++newState ) { + newStates.set(newState, newStates[data.simpleStateToOriginalState[newState]]); + } labeling.addLabel(labelName, std::move(newStates)); + } return labeling; } diff --git a/src/storm-pomdp/transformer/BinaryPomdpTransformer.h b/src/storm-pomdp/transformer/BinaryPomdpTransformer.h index 734592ead..1e5fae76f 100644 --- a/src/storm-pomdp/transformer/BinaryPomdpTransformer.h +++ b/src/storm-pomdp/transformer/BinaryPomdpTransformer.h @@ -21,6 +21,7 @@ namespace storm { storm::storage::SparseMatrix simpleMatrix; std::vector simpleObservations; std::vector originalToSimpleChoiceMap; + std::vector simpleStateToOriginalState; }; TransformationData transformTransitions(storm::models::sparse::Pomdp const& pomdp, bool transformSimple) const;