Browse Source

keep state valuations in transformers on POMDPs

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
fa34f44989
  1. 5
      src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp
  2. 4
      src/storm-pomdp/transformer/BinaryPomdpTransformer.h
  3. 11
      src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp
  4. 3
      src/storm-pomdp/transformer/PomdpMemoryUnfolder.h
  5. 8
      src/storm/storage/sparse/StateValuations.cpp
  6. 2
      src/storm/storage/sparse/StateValuations.h

5
src/storm-pomdp/transformer/BinaryPomdpTransformer.cpp

@ -15,7 +15,7 @@ namespace storm {
}
template<typename ValueType>
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> BinaryPomdpTransformer<ValueType>::transform(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple) const {
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> BinaryPomdpTransformer<ValueType>::transform(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple, bool keepStateValuations) const {
auto data = transformTransitions(pomdp, transformSimple);
storm::storage::sparse::ModelComponents<ValueType> components;
components.stateLabeling = transformStateLabeling(pomdp, data);
@ -24,6 +24,9 @@ namespace storm {
}
components.transitionMatrix = std::move(data.simpleMatrix);
components.observabilityClasses = std::move(data.simpleObservations);
if (keepStateValuations && pomdp.hasStateValuations()) {
components.stateValuations = pomdp.getStateValuations().blowup(data.simpleStateToOriginalState);
}
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(std::move(components), true);
}

4
src/storm-pomdp/transformer/BinaryPomdpTransformer.h

@ -13,7 +13,7 @@ namespace storm {
BinaryPomdpTransformer();
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> transform(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple) const;
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> transform(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple, bool keepStateValuations = false) const;
private:
@ -27,7 +27,7 @@ namespace storm {
TransformationData transformTransitions(storm::models::sparse::Pomdp<ValueType> const& pomdp, bool transformSimple) const;
storm::models::sparse::StateLabeling transformStateLabeling(storm::models::sparse::Pomdp<ValueType> const& pomdp, TransformationData const& data) const;
storm::models::sparse::StandardRewardModel<ValueType> transformRewardModel(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::models::sparse::StandardRewardModel<ValueType> const& rewardModel, TransformationData const& data) const;
storm::models::sparse::ChoiceLabeling transformChoiceLabeling(storm::models::sparse::Pomdp<ValueType> const& pomdp, TransformationData const& data) const;
};
}
}

11
src/storm-pomdp/transformer/PomdpMemoryUnfolder.cpp

@ -11,8 +11,8 @@ namespace storm {
template<typename ValueType>
PomdpMemoryUnfolder<ValueType>::PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels)
: pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels) {
PomdpMemoryUnfolder<ValueType>::PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels, bool keepStateValuations)
: pomdp(pomdp), memory(memory), addMemoryLabels(addMemoryLabels), keepStateValuations(keepStateValuations) {
// intentionally left empty
}
@ -30,6 +30,13 @@ namespace storm {
auto reachableStates = storm::utility::graph::getReachableStates(components.transitionMatrix, components.stateLabeling.getStates("init"), allStates, ~allStates);
components.transitionMatrix = components.transitionMatrix.getSubmatrix(true, reachableStates, reachableStates);
components.stateLabeling = components.stateLabeling.getSubLabeling(reachableStates);
if (keepStateValuations && pomdp.hasStateValuations()) {
std::vector<uint64_t> newToOldStates(pomdp.getNumberOfStates() * memory.getNumberOfStates(), 0);
for (uint64_t newState = 0; newState < newToOldStates.size(); newState++) {
newToOldStates[newState] = getModelState(newState);
}
components.stateValuations = pomdp.getStateValuations().blowup(newToOldStates).selectStates(reachableStates);
}
// build the remaining components
components.observabilityClasses = transformObservabilityClasses(reachableStates);

3
src/storm-pomdp/transformer/PomdpMemoryUnfolder.h

@ -12,7 +12,7 @@ namespace storm {
public:
PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels = false);
PomdpMemoryUnfolder(storm::models::sparse::Pomdp<ValueType> const& pomdp, storm::storage::PomdpMemory const& memory, bool addMemoryLabels = false, bool keepStateValuations = false);
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> transform() const;
@ -35,6 +35,7 @@ namespace storm {
storm::storage::PomdpMemory const& memory;
bool addMemoryLabels;
bool keepStateValuations;
};
}
}

8
src/storm/storage/sparse/StateValuations.cpp

@ -248,6 +248,14 @@ namespace storm {
return StateValuations(variableToIndexMap, std::move(selectedValuations));
}
StateValuations StateValuations::blowup(const std::vector<uint64_t> &mapNewToOld) const {
std::vector<StateValuation> newValuations;
for( auto const& oldState : mapNewToOld) {
newValuations.push_back(valuations[oldState]);
}
return StateValuations(variableToIndexMap, std::move(newValuations));
}
StateValuationsBuilder::StateValuationsBuilder() : booleanVarCount(0), integerVarCount(0), rationalVarCount(0) {
// Intentionally left empty.
}

2
src/storm/storage/sparse/StateValuations.h

@ -109,6 +109,8 @@ namespace storm {
*/
StateValuations selectStates(std::vector<storm::storage::sparse::state_type> const& selectedStates) const;
StateValuations blowup(std::vector<uint64_t> const& mapNewToOld) const;
virtual std::size_t hash() const;
private:

Loading…
Cancel
Save