Browse Source

Added state valuations in unfolding.

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
a7c6b39f19
  1. 14
      src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp
  2. 3
      src/storm-pomdp/transformer/ObservationTraceUnfolder.h

14
src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp

@ -1,11 +1,13 @@
#include "storm/exceptions/InvalidArgumentException.h"
#include "storm-pomdp/transformer/ObservationTraceUnfolder.h"
#include "storm/storage/expressions/ExpressionManager.h"
namespace storm {
namespace pomdp {
template<typename ValueType>
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model) : model(model) {
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model,
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager) : model(model), exprManager(exprManager) {
statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates()));
for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) {
statesPerObservation[model.getObservation(state)].set(state, true);
@ -32,6 +34,9 @@ namespace storm {
statesPerObservation.resize(model.getNrObservations() + 1);
statesPerObservation[model.getNrObservations()] = actualInitialStates;
storm::storage::sparse::StateValuationsBuilder svbuilder;
auto svvar = exprManager->declareFreshIntegerVariable(false, "_s");
svbuilder.addVariable(svvar);
std::map<uint64_t,uint64_t> unfoldedToOld;
std::map<uint64_t,uint64_t> unfoldedToOldNextStep;
@ -56,6 +61,7 @@ namespace storm {
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl;
assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1);
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)});
uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second];
uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1];
@ -111,6 +117,8 @@ namespace storm {
uint64_t sinkState = newStateIndex;
uint64_t targetState = newStateIndex + 1;
for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) {
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)});
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
if (!storm::utility::isZero(storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second])) {
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState,
@ -125,10 +133,13 @@ namespace storm {
// sink state
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>());
svbuilder.addState(sinkState, {}, {-1});
newRowGroupStart++;
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
// target state
transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>());
svbuilder.addState(targetState, {}, {-1});
@ -144,6 +155,7 @@ namespace storm {
labeling.addLabel("init");
labeling.addLabelToState("init", 0);
components.stateLabeling = labeling;
components.stateValuations = svbuilder.build( components.transitionMatrix.getRowGroupCount());
return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components));

3
src/storm-pomdp/transformer/ObservationTraceUnfolder.h

@ -6,10 +6,11 @@ namespace storm {
class ObservationTraceUnfolder {
public:
ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model);
ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model, std::shared_ptr<storm::expressions::ExpressionManager>& exprManager);
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> transform(std::vector<uint32_t> const& observations, std::vector<ValueType> const& risk);
private:
storm::models::sparse::Pomdp<ValueType> const& model;
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager;
std::vector<storm::storage::BitVector> statesPerObservation;
};

Loading…
Cancel
Save