Browse Source

fix for the binary pomdp transformer: labels are now attached to auxiliary states as well

main
Sebastian Junges 5 years ago
parent
commit
b901a1b308
  1. 19
      src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp
  2. 1
      src/storm-pomdp/transformer/BinaryPomdpTransformer.h

19
src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp

@ -29,10 +29,11 @@ namespace storm {
} }
struct BinaryPomdpTransformerRowGroup { struct BinaryPomdpTransformerRowGroup {
BinaryPomdpTransformerRowGroup(uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) { BinaryPomdpTransformerRowGroup(uint64_t origState, uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : origState(origState), firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) {
// Intentionally left empty. // Intentionally left empty.
} }
uint64_t origState;
uint64_t firstRow; uint64_t firstRow;
uint64_t endRow; uint64_t endRow;
uint32_t origStateObservation; uint32_t origStateObservation;
@ -46,11 +47,11 @@ namespace storm {
assert(size() > 1); assert(size() > 1);
uint64_t midRow = firstRow + size()/2; uint64_t midRow = firstRow + size()/2;
std::vector<BinaryPomdpTransformerRowGroup> res; std::vector<BinaryPomdpTransformerRowGroup> res;
res.emplace_back(firstRow, midRow, origStateObservation); res.emplace_back(origState, firstRow, midRow, origStateObservation);
storm::storage::BitVector newAuxStateId = auxStateId; storm::storage::BitVector newAuxStateId = auxStateId;
newAuxStateId.resize(auxStateId.size() + 1, false); newAuxStateId.resize(auxStateId.size() + 1, false);
res.back().auxStateId = newAuxStateId; res.back().auxStateId = newAuxStateId;
res.emplace_back(midRow, endRow, origStateObservation); res.emplace_back(origState, midRow, endRow, origStateObservation);
newAuxStateId.set(auxStateId.size(), true); newAuxStateId.set(auxStateId.size(), true);
res.back().auxStateId = newAuxStateId; res.back().auxStateId = newAuxStateId;
return res; return res;
@ -75,7 +76,7 @@ namespace storm {
// Initialize a FIFO Queue that stores the start and the end of each row group // Initialize a FIFO Queue that stores the start and the end of each row group
std::queue<BinaryPomdpTransformerRowGroup> queue; std::queue<BinaryPomdpTransformerRowGroup> queue;
for (uint64_t state = 0; state < matrix.getRowGroupCount(); ++state) { for (uint64_t state = 0; state < matrix.getRowGroupCount(); ++state) {
queue.emplace(matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state)); queue.emplace(state, matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state));
} }
std::vector<uint32_t> newObservations; std::vector<uint32_t> newObservations;
@ -84,6 +85,8 @@ namespace storm {
uint64_t currRow = 0; uint64_t currRow = 0;
std::vector<uint64_t> origRowToSimpleRowMap(pomdp.getNumberOfChoices(), std::numeric_limits<uint64_t>::max()); std::vector<uint64_t> origRowToSimpleRowMap(pomdp.getNumberOfChoices(), std::numeric_limits<uint64_t>::max());
uint64_t currAuxState = queue.size(); uint64_t currAuxState = queue.size();
std::vector<uint64_t> origStates;
while (!queue.empty()) { while (!queue.empty()) {
auto group = std::move(queue.front()); auto group = std::move(queue.front());
queue.pop(); queue.pop();
@ -120,12 +123,14 @@ namespace storm {
} }
} }
// Nothing to be done if group has size zero // Nothing to be done if group has size zero
origStates.push_back(group.origState);
} }
TransformationData result; TransformationData result;
result.simpleMatrix = builder.build(currRow, currAuxState, currAuxState); result.simpleMatrix = builder.build(currRow, currAuxState, currAuxState);
result.simpleObservations = std::move(newObservations); result.simpleObservations = std::move(newObservations);
result.originalToSimpleChoiceMap = std::move(origRowToSimpleRowMap); result.originalToSimpleChoiceMap = std::move(origRowToSimpleRowMap);
result.simpleStateToOriginalState = std::move(origStates);
return result; return result;
} }
@ -136,7 +141,11 @@ namespace storm {
for (auto const& labelName : pomdp.getStateLabeling().getLabels()) { for (auto const& labelName : pomdp.getStateLabeling().getLabels()) {
storm::storage::BitVector newStates = pomdp.getStateLabeling().getStates(labelName); storm::storage::BitVector newStates = pomdp.getStateLabeling().getStates(labelName);
newStates.resize(data.simpleMatrix.getRowGroupCount(), false); newStates.resize(data.simpleMatrix.getRowGroupCount(), false);
for (uint64_t newState = pomdp.getNumberOfStates(); newState < data.simpleMatrix.getRowGroupCount(); ++newState ) {
newStates.set(newState, newStates[data.simpleStateToOriginalState[newState]]);
}
labeling.addLabel(labelName, std::move(newStates)); labeling.addLabel(labelName, std::move(newStates));
} }
return labeling; return labeling;
} }

1
src/storm-pomdp/transformer/BinaryPomdpTransformer.h

@ -21,6 +21,7 @@ namespace storm {
storm::storage::SparseMatrix<ValueType> simpleMatrix; storm::storage::SparseMatrix<ValueType> simpleMatrix;
std::vector<uint32_t> simpleObservations; std::vector<uint32_t> simpleObservations;
std::vector<uint64_t> originalToSimpleChoiceMap; std::vector<uint64_t> originalToSimpleChoiceMap;
std::vector<uint64_t> simpleStateToOriginalState;
}; };
TransformationData transformTransitions(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple) const; TransformationData transformTransitions(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple) const;

|||||||
100:0
Loading…
Cancel
Save