|
|
@ -29,10 +29,11 @@ namespace storm { |
|
|
|
} |
|
|
|
|
|
|
|
struct BinaryPomdpTransformerRowGroup { |
|
|
|
BinaryPomdpTransformerRowGroup(uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) { |
|
|
|
BinaryPomdpTransformerRowGroup(uint64_t origState, uint64_t firstRow, uint64_t endRow, uint32_t origStateObservation) : origState(origState), firstRow(firstRow), endRow(endRow), origStateObservation(origStateObservation) { |
|
|
|
// Intentionally left empty.
|
|
|
|
} |
|
|
|
|
|
|
|
uint64_t origState; |
|
|
|
uint64_t firstRow; |
|
|
|
uint64_t endRow; |
|
|
|
uint32_t origStateObservation; |
|
|
@ -46,11 +47,11 @@ namespace storm { |
|
|
|
assert(size() > 1); |
|
|
|
uint64_t midRow = firstRow + size()/2; |
|
|
|
std::vector<BinaryPomdpTransformerRowGroup> res; |
|
|
|
res.emplace_back(firstRow, midRow, origStateObservation); |
|
|
|
res.emplace_back(origState, firstRow, midRow, origStateObservation); |
|
|
|
storm::storage::BitVector newAuxStateId = auxStateId; |
|
|
|
newAuxStateId.resize(auxStateId.size() + 1, false); |
|
|
|
res.back().auxStateId = newAuxStateId; |
|
|
|
res.emplace_back(midRow, endRow, origStateObservation); |
|
|
|
res.emplace_back(origState, midRow, endRow, origStateObservation); |
|
|
|
newAuxStateId.set(auxStateId.size(), true); |
|
|
|
res.back().auxStateId = newAuxStateId; |
|
|
|
return res; |
|
|
@ -75,7 +76,7 @@ namespace storm { |
|
|
|
// Initialize a FIFO Queue that stores the start and the end of each row group
|
|
|
|
std::queue<BinaryPomdpTransformerRowGroup> queue; |
|
|
|
for (uint64_t state = 0; state < matrix.getRowGroupCount(); ++state) { |
|
|
|
queue.emplace(matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state)); |
|
|
|
queue.emplace(state, matrix.getRowGroupIndices()[state], matrix.getRowGroupIndices()[state+1], pomdp.getObservation(state)); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<uint32_t> newObservations; |
|
|
@ -84,6 +85,8 @@ namespace storm { |
|
|
|
uint64_t currRow = 0; |
|
|
|
std::vector<uint64_t> origRowToSimpleRowMap(pomdp.getNumberOfChoices(), std::numeric_limits<uint64_t>::max()); |
|
|
|
uint64_t currAuxState = queue.size(); |
|
|
|
std::vector<uint64_t> origStates; |
|
|
|
|
|
|
|
while (!queue.empty()) { |
|
|
|
auto group = std::move(queue.front()); |
|
|
|
queue.pop(); |
|
|
@ -120,12 +123,14 @@ namespace storm { |
|
|
|
} |
|
|
|
} |
|
|
|
// Nothing to be done if group has size zero
|
|
|
|
origStates.push_back(group.origState); |
|
|
|
} |
|
|
|
|
|
|
|
TransformationData result; |
|
|
|
result.simpleMatrix = builder.build(currRow, currAuxState, currAuxState); |
|
|
|
result.simpleObservations = std::move(newObservations); |
|
|
|
result.originalToSimpleChoiceMap = std::move(origRowToSimpleRowMap); |
|
|
|
result.simpleStateToOriginalState = std::move(origStates); |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
@ -136,7 +141,11 @@ namespace storm { |
|
|
|
for (auto const& labelName : pomdp.getStateLabeling().getLabels()) { |
|
|
|
storm::storage::BitVector newStates = pomdp.getStateLabeling().getStates(labelName); |
|
|
|
newStates.resize(data.simpleMatrix.getRowGroupCount(), false); |
|
|
|
for (uint64_t newState = pomdp.getNumberOfStates(); newState < data.simpleMatrix.getRowGroupCount(); ++newState ) { |
|
|
|
newStates.set(newState, newStates[data.simpleStateToOriginalState[newState]]); |
|
|
|
} |
|
|
|
labeling.addLabel(labelName, std::move(newStates)); |
|
|
|
|
|
|
|
} |
|
|
|
return labeling; |
|
|
|
} |
|
|
|