Browse Source

fix for the binary pomdp transformer: labels are now attached to auxiliary states as well

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
b901a1b308
  1. 17
      src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp
  2. 1
      src/storm-pomdp/transformer/BinaryPomdpTransformer.h

17
src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp

@ -29,10 +29,11 @@ namespace storm {
}
struct BinaryPomdpTransformerRowGroup {
BinaryPomdpTransformerRowGroup(uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) {
BinaryPomdpTransformerRowGroup(uint64_t origState, uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : origState(origState), firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) {
// Intentionally left empty.
}
uint64_t origState;
uint64_t firstRow;
uint64_t endRow;
uint32_t origStateObservation;
@ -46,11 +47,11 @@ namespace storm {
assert(size() > 1);
uint64_t midRow = firstRow + size()/2;
std::vector<BinaryPomdpTransformerRowGroup> res;
res.emplace_back(firstRow, midRow, origStateObservation);
res.emplace_back(origState, firstRow, midRow, origStateObservation);
storm::storage::BitVector newAuxStateId = auxStateId;
newAuxStateId.resize(auxStateId.size() + 1, false);
res.back().auxStateId = newAuxStateId;
res.emplace_back(midRow, endRow, origStateObservation);
res.emplace_back(origState, midRow, endRow, origStateObservation);
newAuxStateId.set(auxStateId.size(), true);
res.back().auxStateId = newAuxStateId;
return res;
@ -75,7 +76,7 @@ namespace storm {
// Initialize a FIFO Queue that stores the start and the end of each row group
std::queue<BinaryPomdpTransformerRowGroup> queue;
for (uint64_t state = 0; state < matrix.getRowGroupCount(); ++state) {
queue.emplace(matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state));
queue.emplace(state, matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state));
}
std::vector<uint32_t> newObservations;
@ -84,6 +85,8 @@ namespace storm {
uint64_t currRow = 0;
std::vector<uint64_t> origRowToSimpleRowMap(pomdp.getNumberOfChoices(), std::numeric_limits<uint64_t>::max());
uint64_t currAuxState = queue.size();
std::vector<uint64_t> origStates;
while (!queue.empty()) {
auto group = std::move(queue.front());
queue.pop();
@ -120,12 +123,14 @@ namespace storm {
}
}
// Nothing to be done if group has size zero
origStates.push_back(group.origState);
}
TransformationData result;
result.simpleMatrix = builder.build(currRow, currAuxState, currAuxState);
result.simpleObservations = std::move(newObservations);
result.originalToSimpleChoiceMap = std::move(origRowToSimpleRowMap);
result.simpleStateToOriginalState = std::move(origStates);
return result;
}
@ -136,7 +141,11 @@ namespace storm {
for (auto const& labelName : pomdp.getStateLabeling().getLabels()) {
storm::storage::BitVector newStates = pomdp.getStateLabeling().getStates(labelName);
newStates.resize(data.simpleMatrix.getRowGroupCount(), false);
for (uint64_t newState = pomdp.getNumberOfStates(); newState < data.simpleMatrix.getRowGroupCount(); ++newState ) {
newStates.set(newState, newStates[data.simpleStateToOriginalState[newState]]);
}
labeling.addLabel(labelName, std::move(newStates));
}
return labeling;
}

1
src/storm-pomdp/transformer/BinaryPomdpTransformer.h

@ -21,6 +21,7 @@ namespace storm {
storm::storage::SparseMatrix<ValueType> simpleMatrix;
std::vector<uint32_t> simpleObservations;
std::vector<uint64_t> originalToSimpleChoiceMap;
std::vector<uint64_t> simpleStateToOriginalState;
};
TransformationData transformTransitions(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple) const;

Loading…
Cancel
Save