diff --git a/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp b/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp index baace3c5f..d3838dc2f 100644 --- a/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp +++ b/src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp @@ -15,7 +15,7 @@ namespace storm { } template - std::shared_ptr> BinaryPomdpTransformer::transform(storm::models::sparse::Pomdp const& pomdp, bool transformSimple) const { + std::shared_ptr> BinaryPomdpTransformer::transform(storm::models::sparse::Pomdp const& pomdp, bool transformSimple, bool keepStateValuations) const { auto data = transformTransitions(pomdp, transformSimple); storm::storage::sparse::ModelComponents components; components.stateLabeling = transformStateLabeling(pomdp, data); @@ -24,6 +24,9 @@ namespace storm { } components.transitionMatrix = std::move(data.simpleMatrix); components.observabilityClasses = std::move(data.simpleObservations); + if (keepStateValuations && pomdp.hasStateValuations()) { + components.stateValuations = pomdp.getStateValuations().blowup(data.simpleStateToOriginalState); + } return std::make_shared>(std::move(components), true); } diff --git a/src/storm-pomdp/transformer/BinaryPomdpTransformer.h b/src/storm-pomdp/transformer/BinaryPomdpTransformer.h index 1e5fae76f..07090a905 100644 --- a/src/storm-pomdp/transformer/BinaryPomdpTransformer.h +++ b/src/storm-pomdp/transformer/BinaryPomdpTransformer.h @@ -13,7 +13,7 @@ namespace storm { BinaryPomdpTransformer(); - std::shared_ptr> transform(storm::models::sparse::Pomdp const& pomdp, bool transformSimple) const; + std::shared_ptr> transform(storm::models::sparse::Pomdp const& pomdp, bool transformSimple, bool keepStateValuations = false) const; private: @@ -27,7 +27,7 @@ namespace storm { TransformationData transformTransitions(storm::models::sparse::Pomdp const& pomdp, bool transformSimple) const; storm::models::sparse::StateLabeling transformStateLabeling(storm::models::sparse::Pomdp const& pomdp, TransformationData const& data) const; storm::models::sparse::StandardRewardModel transformRewardModel(storm::models::sparse::Pomdp const& pomdp, storm::models::sparse::StandardRewardModel const& rewardModel, TransformationData const& data) const; - + storm::models::sparse::ChoiceLabeling transformChoiceLabeling(storm::models::sparse::Pomdp const& pomdp, TransformationData const& data) const; }; } } \ No newline at end of file diff --git a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp index 6687c30db..8e351caf9 100644 --- a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp +++ b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp @@ -11,8 +11,8 @@ namespace storm { template - PomdpMemoryUnfolder::PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels) - : pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels) { + PomdpMemoryUnfolder::PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels, bool keepStateValuations) + : pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels), keepStateValuations(keepStateValuations) { // intentionally left empty } @@ -30,6 +30,13 @@ namespace storm { auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates); components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates); components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates); + if (keepStateValuations && pomdp.hasStateValuations()) { + std::vector newToOldStates(pomdp.getNumberOfStates() * memory.getNumberOfStates(), 0); + for (uint64_t newState = 0; newState < newToOldStates.size(); newState++) { + newToOldStates[newState] = getModelState(newState); + } + components.stateValuations = pomdp.getStateValuations().blowup(newToOldStates).selectStates(reachableStates); + } // build the remaining components components.observabilityClasses = transformObservabilityClasses(reachableStates); diff --git a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h index 41f8d0526..6cb59c3e9 100644 --- a/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h +++ b/src/storm-pomdp/transformer/PomdpMemoryUnfolder.h @@ -12,7 +12,7 @@ namespace storm { public: - PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels = false); + PomdpMemoryUnfolder(storm::models::sparse::Pomdp const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels = false, bool keepStateValuations = false); std::shared_ptr> transform() const; @@ -35,6 +35,7 @@ namespace storm { storm::storage::PomdpMemory const& memory; bool addMemoryLabels; + bool keepStateValuations; }; } } \ No newline at end of file diff --git a/src/storm/storage/sparse/StateValuations.cpp b/src/storm/storage/sparse/StateValuations.cpp index 541e56b90..271e8aee9 100644 --- a/src/storm/storage/sparse/StateValuations.cpp +++ b/src/storm/storage/sparse/StateValuations.cpp @@ -247,6 +247,14 @@ namespace storm { } return StateValuations(variableToIndexMap, std::move(selectedValuations)); } + + StateValuations StateValuations::blowup(const std::vector &mapNewToOld) const { + std::vector newValuations; + for( auto const& oldState : mapNewToOld) { + newValuations.push_back(valuations[oldState]); + } + return StateValuations(variableToIndexMap, std::move(newValuations)); + } StateValuationsBuilder::StateValuationsBuilder() : booleanVarCount(0), integerVarCount(0), rationalVarCount(0) { // Intentionally left empty. diff --git a/src/storm/storage/sparse/StateValuations.h b/src/storm/storage/sparse/StateValuations.h index 07027d52b..4199a69a3 100644 --- a/src/storm/storage/sparse/StateValuations.h +++ b/src/storm/storage/sparse/StateValuations.h @@ -109,6 +109,8 @@ namespace storm { */ StateValuations selectStates(std::vector const& selectedStates) const; + StateValuations blowup(std::vector const& mapNewToOld) const; + virtual std::size_t hash() const; private: