diff --git a/src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp b/src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp index a300e1880..1d8ca5399 100644 --- a/src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp +++ b/src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp @@ -1,11 +1,13 @@ #include "storm/exceptions/InvalidArgumentException.h" #include "storm-pomdp/transformer/ObservationTraceUnfolder.h" +#include "storm/storage/expressions/ExpressionManager.h" namespace storm { namespace pomdp { template - ObservationTraceUnfolder::ObservationTraceUnfolder(storm::models::sparse::Pomdp const& model) : model(model) { + ObservationTraceUnfolder::ObservationTraceUnfolder(storm::models::sparse::Pomdp const& model, + std::shared_ptr& exprManager) : model(model), exprManager(exprManager) { statesPerObservation = std::vector(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates())); for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) { statesPerObservation[model.getObservation(state)].set(state, true); @@ -32,6 +34,9 @@ namespace storm { statesPerObservation.resize(model.getNrObservations() + 1); statesPerObservation[model.getNrObservations()] = actualInitialStates; + storm::storage::sparse::StateValuationsBuilder svbuilder; + auto svvar = exprManager->declareFreshIntegerVariable(false, "_s"); + svbuilder.addVariable(svvar); std::map unfoldedToOld; std::map unfoldedToOldNextStep; @@ -56,6 +61,7 @@ namespace storm { transitionMatrixBuilder.newRowGroup(newRowGroupStart); std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl; assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1); + svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast(unfoldedToOldEntry.second)}); uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second]; uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1]; @@ -111,6 +117,8 @@ namespace storm { uint64_t sinkState = newStateIndex; uint64_t targetState = newStateIndex + 1; for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) { + svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast(unfoldedToOldEntry.second)}); + transitionMatrixBuilder.newRowGroup(newRowGroupStart); if (!storm::utility::isZero(storm::utility::one() - risk[unfoldedToOldEntry.second])) { transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, @@ -125,10 +133,13 @@ namespace storm { // sink state transitionMatrixBuilder.newRowGroup(newRowGroupStart); transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one()); + svbuilder.addState(sinkState, {}, {-1}); + newRowGroupStart++; transitionMatrixBuilder.newRowGroup(newRowGroupStart); // target state transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one()); + svbuilder.addState(targetState, {}, {-1}); @@ -144,6 +155,7 @@ namespace storm { labeling.addLabel("init"); labeling.addLabelToState("init", 0); components.stateLabeling = labeling; + components.stateValuations = svbuilder.build( components.transitionMatrix.getRowGroupCount()); return std::make_shared>(std::move(components)); diff --git a/src/storm-pomdp/transformer/ObservationTraceUnfolder.h b/src/storm-pomdp/transformer/ObservationTraceUnfolder.h index feafe57e2..421cb4db9 100644 --- a/src/storm-pomdp/transformer/ObservationTraceUnfolder.h +++ b/src/storm-pomdp/transformer/ObservationTraceUnfolder.h @@ -6,10 +6,11 @@ namespace storm { class ObservationTraceUnfolder { public: - ObservationTraceUnfolder(storm::models::sparse::Pomdp const& model); + ObservationTraceUnfolder(storm::models::sparse::Pomdp const& model, std::shared_ptr& exprManager); std::shared_ptr> transform(std::vector const& observations, std::vector const& risk); private: storm::models::sparse::Pomdp const& model; + std::shared_ptr& exprManager; std::vector statesPerObservation; };