|
@ -1,11 +1,13 @@ |
|
|
#include "storm/exceptions/InvalidArgumentException.h"
|
|
|
#include "storm/exceptions/InvalidArgumentException.h"
|
|
|
#include "storm-pomdp/transformer/ObservationTraceUnfolder.h"
|
|
|
#include "storm-pomdp/transformer/ObservationTraceUnfolder.h"
|
|
|
|
|
|
#include "storm/storage/expressions/ExpressionManager.h"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace storm { |
|
|
namespace storm { |
|
|
namespace pomdp { |
|
|
namespace pomdp { |
|
|
template<typename ValueType> |
|
|
template<typename ValueType> |
|
|
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model) : model(model) { |
|
|
|
|
|
|
|
|
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model, |
|
|
|
|
|
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager) : model(model), exprManager(exprManager) { |
|
|
statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates())); |
|
|
statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates())); |
|
|
for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) { |
|
|
for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) { |
|
|
statesPerObservation[model.getObservation(state)].set(state, true); |
|
|
statesPerObservation[model.getObservation(state)].set(state, true); |
|
@ -32,6 +34,9 @@ namespace storm { |
|
|
statesPerObservation.resize(model.getNrObservations() + 1); |
|
|
statesPerObservation.resize(model.getNrObservations() + 1); |
|
|
statesPerObservation[model.getNrObservations()] = actualInitialStates; |
|
|
statesPerObservation[model.getNrObservations()] = actualInitialStates; |
|
|
|
|
|
|
|
|
|
|
|
storm::storage::sparse::StateValuationsBuilder svbuilder; |
|
|
|
|
|
auto svvar = exprManager->declareFreshIntegerVariable(false, "_s"); |
|
|
|
|
|
svbuilder.addVariable(svvar); |
|
|
|
|
|
|
|
|
std::map<uint64_t,uint64_t> unfoldedToOld; |
|
|
std::map<uint64_t,uint64_t> unfoldedToOld; |
|
|
std::map<uint64_t,uint64_t> unfoldedToOldNextStep; |
|
|
std::map<uint64_t,uint64_t> unfoldedToOldNextStep; |
|
@ -56,6 +61,7 @@ namespace storm { |
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl; |
|
|
std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl; |
|
|
assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1); |
|
|
assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1); |
|
|
|
|
|
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)}); |
|
|
uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second]; |
|
|
uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second]; |
|
|
uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1]; |
|
|
uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1]; |
|
|
|
|
|
|
|
@ -111,6 +117,8 @@ namespace storm { |
|
|
uint64_t sinkState = newStateIndex; |
|
|
uint64_t sinkState = newStateIndex; |
|
|
uint64_t targetState = newStateIndex + 1; |
|
|
uint64_t targetState = newStateIndex + 1; |
|
|
for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) { |
|
|
for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) { |
|
|
|
|
|
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)}); |
|
|
|
|
|
|
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
if (!storm::utility::isZero(storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second])) { |
|
|
if (!storm::utility::isZero(storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second])) { |
|
|
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, |
|
|
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, |
|
@ -125,10 +133,13 @@ namespace storm { |
|
|
// sink state
|
|
|
// sink state
|
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>()); |
|
|
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>()); |
|
|
|
|
|
svbuilder.addState(sinkState, {}, {-1}); |
|
|
|
|
|
|
|
|
newRowGroupStart++; |
|
|
newRowGroupStart++; |
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
transitionMatrixBuilder.newRowGroup(newRowGroupStart); |
|
|
// target state
|
|
|
// target state
|
|
|
transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>()); |
|
|
transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>()); |
|
|
|
|
|
svbuilder.addState(targetState, {}, {-1}); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -144,6 +155,7 @@ namespace storm { |
|
|
labeling.addLabel("init"); |
|
|
labeling.addLabel("init"); |
|
|
labeling.addLabelToState("init", 0); |
|
|
labeling.addLabelToState("init", 0); |
|
|
components.stateLabeling = labeling; |
|
|
components.stateLabeling = labeling; |
|
|
|
|
|
components.stateValuations = svbuilder.build( components.transitionMatrix.getRowGroupCount()); |
|
|
return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components)); |
|
|
return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components)); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|