Browse Source

Added state valuations in unfolding.

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
a7c6b39f19
  1. 14
      src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp
  2. 3
      src/storm-pomdp/transformer/ObservationTraceUnfolder.h

14
src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp

@ -1,11 +1,13 @@
#include "storm/exceptions/InvalidArgumentException.h" #include "storm/exceptions/InvalidArgumentException.h"
#include "storm-pomdp/transformer/ObservationTraceUnfolder.h" #include "storm-pomdp/transformer/ObservationTraceUnfolder.h"
#include "storm/storage/expressions/ExpressionManager.h"
namespace storm { namespace storm {
namespace pomdp { namespace pomdp {
template<typename ValueType> template<typename ValueType>
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model) : model(model) {
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model,
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager) : model(model), exprManager(exprManager) {
statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates())); statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations(), storm::storage::BitVector(model.getNumberOfStates()));
for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) { for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) {
statesPerObservation[model.getObservation(state)].set(state, true); statesPerObservation[model.getObservation(state)].set(state, true);
@ -32,6 +34,9 @@ namespace storm {
statesPerObservation.resize(model.getNrObservations() + 1); statesPerObservation.resize(model.getNrObservations() + 1);
statesPerObservation[model.getNrObservations()] = actualInitialStates; statesPerObservation[model.getNrObservations()] = actualInitialStates;
storm::storage::sparse::StateValuationsBuilder svbuilder;
auto svvar = exprManager->declareFreshIntegerVariable(false, "_s");
svbuilder.addVariable(svvar);
std::map<uint64_t,uint64_t> unfoldedToOld; std::map<uint64_t,uint64_t> unfoldedToOld;
std::map<uint64_t,uint64_t> unfoldedToOldNextStep; std::map<uint64_t,uint64_t> unfoldedToOldNextStep;
@ -56,6 +61,7 @@ namespace storm {
transitionMatrixBuilder.newRowGroup(newRowGroupStart); transitionMatrixBuilder.newRowGroup(newRowGroupStart);
std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl; std::cout << "\tconsider new state " << unfoldedToOldEntry.first << std::endl;
assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1); assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1);
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)});
uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second]; uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second];
uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1]; uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second+1];
@ -111,6 +117,8 @@ namespace storm {
uint64_t sinkState = newStateIndex; uint64_t sinkState = newStateIndex;
uint64_t targetState = newStateIndex + 1; uint64_t targetState = newStateIndex + 1;
for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) { for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) {
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)});
transitionMatrixBuilder.newRowGroup(newRowGroupStart); transitionMatrixBuilder.newRowGroup(newRowGroupStart);
if (!storm::utility::isZero(storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second])) { if (!storm::utility::isZero(storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second])) {
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState,
@ -125,10 +133,13 @@ namespace storm {
// sink state // sink state
transitionMatrixBuilder.newRowGroup(newRowGroupStart); transitionMatrixBuilder.newRowGroup(newRowGroupStart);
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>()); transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>());
svbuilder.addState(sinkState, {}, {-1});
newRowGroupStart++; newRowGroupStart++;
transitionMatrixBuilder.newRowGroup(newRowGroupStart); transitionMatrixBuilder.newRowGroup(newRowGroupStart);
// target state // target state
transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>()); transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>());
svbuilder.addState(targetState, {}, {-1});
@ -144,6 +155,7 @@ namespace storm {
labeling.addLabel("init"); labeling.addLabel("init");
labeling.addLabelToState("init", 0); labeling.addLabelToState("init", 0);
components.stateLabeling = labeling; components.stateLabeling = labeling;
components.stateValuations = svbuilder.build( components.transitionMatrix.getRowGroupCount());
return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components)); return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components));

3
src/storm-pomdp/transformer/ObservationTraceUnfolder.h

@ -6,10 +6,11 @@ namespace storm {
class ObservationTraceUnfolder { class ObservationTraceUnfolder {
public: public:
ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model);
ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model, std::shared_ptr<storm::expressions::ExpressionManager>& exprManager);
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> transform(std::vector<uint32_t> const& observations, std::vector<ValueType> const& risk); std::shared_ptr<storm::models::sparse::Mdp<ValueType>> transform(std::vector<uint32_t> const& observations, std::vector<ValueType> const& risk);
private: private:
storm::models::sparse::Pomdp<ValueType> const& model; storm::models::sparse::Pomdp<ValueType> const& model;
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager;
std::vector<storm::storage::BitVector> statesPerObservation; std::vector<storm::storage::BitVector> statesPerObservation;
}; };

Loading…
Cancel
Save