|
@ -11,8 +11,8 @@ namespace storm { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
PomdpMemoryUnfolder<ValueType>::PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels) |
|
|
|
|
|
: pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels) { |
|
|
|
|
|
|
|
|
PomdpMemoryUnfolder<ValueType>::PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels, bool keepStateValuations) |
|
|
|
|
|
: pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels), keepStateValuations(keepStateValuations) { |
|
|
// intentionally left empty
|
|
|
// intentionally left empty
|
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@ -30,6 +30,13 @@ namespace storm { |
|
|
auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); |
|
|
auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); |
|
|
components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); |
|
|
components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); |
|
|
components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates); |
|
|
components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates); |
|
|
|
|
|
if (keepStateValuations && pomdp.hasStateValuations()) { |
|
|
|
|
|
std::vector<uint64_t> newToOldStates(pomdp.getNumberOfStates() * memory.getNumberOfStates(), 0); |
|
|
|
|
|
for (uint64_t newState = 0; newState < newToOldStates.size(); newState++) { |
|
|
|
|
|
newToOldStates[newState] = getModelState(newState); |
|
|
|
|
|
} |
|
|
|
|
|
components.stateValuations = pomdp.getStateValuations().blowup(newToOldStates).selectStates(reachableStates); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// build the remaining components
|
|
|
// build the remaining components
|
|
|
components.observabilityClasses = transformObservabilityClasses(reachableStates); |
|
|
components.observabilityClasses = transformObservabilityClasses(reachableStates); |
|
|