Browse Source

observation valuations added

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
9423d01631
  1. 1
      src/storm-pomdp/transformer/MakePOMDPCanonic.cpp
  2. 11
      src/storm/builder/BuilderOptions.cpp
  3. 12
      src/storm/builder/BuilderOptions.h
  4. 62
      src/storm/builder/ExplicitModelBuilder.cpp
  5. 57
      src/storm/generator/NextStateGenerator.cpp
  6. 4
      src/storm/generator/NextStateGenerator.h
  7. 7
      src/storm/generator/VariableInformation.cpp
  8. 9
      src/storm/generator/VariableInformation.h
  9. 19
      src/storm/models/sparse/Pomdp.cpp
  10. 12
      src/storm/models/sparse/Pomdp.h
  11. 2
      src/storm/storage/sparse/ModelComponents.h
  12. 116
      src/storm/storage/sparse/StateValuations.cpp
  13. 38
      src/storm/storage/sparse/StateValuations.h

1
src/storm-pomdp/transformer/MakePOMDPCanonic.cpp

@ -151,6 +151,7 @@ namespace storm {
modelcomponents.stateValuations = pomdp.getOptionalStateValuations();
modelcomponents.choiceLabeling = pomdp.getChoiceLabeling();
modelcomponents.choiceLabeling->permuteItems(permutation);
modelcomponents.observationValuations = pomdp.getOptionalObservationValuations();
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(modelcomponents, true);
}

11
src/storm/builder/BuilderOptions.cpp

@ -131,10 +131,14 @@ namespace storm {
return buildChoiceLabels;
}
bool BuilderOptions::isBuildStateValuationsSet() const {
bool BuilderOptions::isBuildStateValuationsSet() const {
return buildStateValuations;
}
bool BuilderOptions::isBuildObservationValuationsSet() const {
return buildObservationValuations;
}
bool BuilderOptions::isBuildChoiceOriginsSet() const {
return buildChoiceOrigins;
}
@ -238,6 +242,11 @@ namespace storm {
return *this;
}
BuilderOptions& BuilderOptions::setBuildObservationValuations(bool newValue) {
buildObservationValuations = newValue;
return *this;
}
BuilderOptions& BuilderOptions::setBuildChoiceOrigins(bool newValue) {
buildChoiceOrigins = newValue;
return *this;

12
src/storm/builder/BuilderOptions.h

@ -106,6 +106,7 @@ namespace storm {
bool isApplyMaximalProgressAssumptionSet() const;
bool isBuildChoiceLabelsSet() const;
bool isBuildStateValuationsSet() const;
bool isBuildObservationValuationsSet() const;
bool isBuildChoiceOriginsSet() const;
bool isBuildAllRewardModelsSet() const;
bool isBuildAllLabelsSet() const;
@ -159,6 +160,14 @@ namespace storm {
* @return this
*/
BuilderOptions& setBuildStateValuations(bool newValue = true);
/**
* Should a observation valuation mapping be built?
* @param newValue The new value (default true)
* @return this
*/
BuilderOptions& setBuildObservationValuations(bool newValue = true);
/**
* Should the origins the different choices be built?
* @param newValue The new value (default true)
@ -237,6 +246,9 @@ namespace storm {
/// A flag indicating whether or not to build for each state the variable valuation from which it originates.
bool buildStateValuations;
/// A flag indicating whether or not to build observation valuations
bool buildObservationValuations;
// A flag that indicates whether or not to generate the information from which parts of the model specification
// each choice originates.
bool buildChoiceOrigins;

62
src/storm/builder/ExplicitModelBuilder.cpp

@ -358,69 +358,17 @@ namespace storm {
}
if (generator->isPartiallyObservable()) {
std::vector<uint32_t> classes;
uint32_t newObservation = 0;
classes.resize(stateStorage.getNumberOfStates());
std::unordered_map<uint32_t, std::vector<std::pair<std::vector<std::string>, uint32_t>>> observationActions;
for (auto const& bitVectorIndexPair : stateStorage.stateToId) {
uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first);
uint32_t observation = -1; // Is replaced later on.
bool checkActionNames = false;
if (checkActionNames) {
bool foundActionSet = false;
std::vector<std::string> actionNames;
bool addedAnonymousAction = false;
for (uint64_t choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second];
choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second +
1]; ++choice) {
if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) {
STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException,
"Cannot have multiple anonymous actions, as these cannot be mapped correctly.");
actionNames.push_back("");
addedAnonymousAction = true;
} else {
STORM_LOG_ASSERT(
modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1,
"Expect choice labelling to contain exactly one label at this point, but found "
<< modelComponents.choiceLabeling.get().getLabelsOfChoice(
choice).size());
actionNames.push_back(
*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin());
}
}
STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: "
<< storm::utility::vector::toString(actionNames));
auto it = observationActions.find(varObservation);
if (it == observationActions.end()) {
observationActions.emplace(varObservation,
std::vector<std::pair<std::vector<std::string>, uint32_t>>());
} else {
for (auto const &entries : it->second) {
STORM_LOG_TRACE(storm::utility::vector::toString(entries.first));
if (entries.first == actionNames) {
observation = entries.second;
foundActionSet = true;
break;
}
}
STORM_LOG_THROW(
generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet,
storm::exceptions::WrongFormatException,
"Two states with the same observation have a different set of enabled actions, this is only allowed with a special option.");
}
if (!foundActionSet) {
observation = newObservation;
observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation);
++newObservation;
}
classes[bitVectorIndexPair.second] = observation;
} else {
classes[bitVectorIndexPair.second] = varObservation;
}
classes[bitVectorIndexPair.second] = varObservation;
}
modelComponents.observabilityClasses = classes;
if(generator->getOptions().isBuildObservationValuationsSet()) {
modelComponents.observationValuations = generator->makeObservationValuation();
}
}
return modelComponents;
}

57
src/storm/generator/NextStateGenerator.cpp

@ -62,6 +62,25 @@ namespace storm {
return result;
}
template<typename ValueType, typename StateType>
storm::storage::sparse::StateValuationsBuilder NextStateGenerator<ValueType, StateType>::initializeObservationValuationsBuilder() const {
storm::storage::sparse::StateValuationsBuilder result;
for (auto const& v : variableInformation.booleanVariables) {
if(v.observable) {
result.addVariable(v.variable);
}
}
for (auto const& v : variableInformation.integerVariables) {
if(v.observable) {
result.addVariable(v.variable);
}
}
for (auto const& l : variableInformation.observationLabels) {
result.addObservationLabel(l.name);
}
return result;
}
template<typename ValueType, typename StateType>
void NextStateGenerator<ValueType, StateType>::load(CompressedState const& state) {
// Since almost all subsequent operations are based on the evaluator, we load the state into it now.
@ -94,6 +113,41 @@ namespace storm {
valuationsBuilder.addState(currentStateIndex, std::move(booleanValues), std::move(integerValues));
}
template<typename ValueType, typename StateType>
storm::storage::sparse::StateValuations NextStateGenerator<ValueType, StateType>::makeObservationValuation() const {
storm::storage::sparse::StateValuationsBuilder valuationsBuilder = initializeObservationValuationsBuilder();
for (auto const& observationEntry : observabilityMap) {
std::vector<bool> booleanValues;
booleanValues.reserve(
variableInformation.booleanVariables.size()); // TODO: use number of observable boolean variables
std::vector<int64_t> integerValues;
integerValues.reserve(variableInformation.locationVariables.size() +
variableInformation.integerVariables.size()); // TODO: use number of observable integer variables
std::vector<int64_t> observationLabelValues;
observationLabelValues.reserve(variableInformation.observationLabels.size());
expressions::SimpleValuation val = unpackStateIntoValuation(observationEntry.first, variableInformation, *expressionManager);
for (auto const& v : variableInformation.booleanVariables) {
if (v.observable) {
booleanValues.push_back(val.getBooleanValue(v.variable));
}
}
for (auto const& v : variableInformation.integerVariables) {
if (v.observable) {
integerValues.push_back(val.getIntegerValue(v.variable));
}
}
for(uint64_t labelStart = variableInformation.getTotalBitOffset(true); labelStart < observationEntry.first.size(); labelStart += 64) {
observationLabelValues.push_back(observationEntry.first.getAsInt(labelStart, 64));
}
valuationsBuilder.addState(observationEntry.second, std::move(booleanValues), std::move(integerValues), {}, std::move(observationLabelValues));
}
return valuationsBuilder.build(observabilityMap.size());
}
template<typename ValueType, typename StateType>
storm::models::sparse::StateLabeling NextStateGenerator<ValueType, StateType>::label(storm::storage::sparse::StateStorage<StateType> const& stateStorage, std::vector<StateType> const& initialStateIndices, std::vector<StateType> const& deadlockStateIndices, std::vector<std::pair<std::string, storm::expressions::Expression>> labelsAndExpressions) {
@ -213,7 +267,8 @@ namespace storm {
if (this->mask.size() == 0) {
this->mask = computeObservabilityMask(variableInformation);
}
return unpackStateToObservabilityClass(state, evaluateObservationLabels(state), observabilityMap, mask);
uint32_t classId = unpackStateToObservabilityClass(state, evaluateObservationLabels(state), observabilityMap, mask);
return classId;
}
template<typename ValueType, typename StateType>

4
src/storm/generator/NextStateGenerator.h

@ -64,6 +64,8 @@ namespace storm {
/// Adds the valuation for the currently loaded state to the given builder
virtual void addStateValuation(storm::storage::sparse::state_type const& currentStateIndex, storm::storage::sparse::StateValuationsBuilder& valuationsBuilder) const;
/// Adds the valuation for the currently loaded state
virtual storm::storage::sparse::StateValuations makeObservationValuation() const;
virtual std::size_t getNumberOfRewardModels() const = 0;
virtual storm::builder::RewardModelInformation getRewardModelInformation(uint64_t const& index) const = 0;
@ -95,6 +97,8 @@ namespace storm {
virtual storm::storage::BitVector evaluateObservationLabels(CompressedState const& state) const =0;
virtual storm::storage::sparse::StateValuationsBuilder initializeObservationValuationsBuilder() const;
void postprocess(StateBehavior<ValueType, StateType>& result);
/// The options to be used for next-state generation.

7
src/storm/generator/VariableInformation.cpp

@ -30,6 +30,10 @@ namespace storm {
// Intentionally left empty.
}
ObservationLabelInformation::ObservationLabelInformation(const std::string &name) : name(name) {
// Intentionally left empty.
}
VariableInformation::VariableInformation(storm::prism::Program const& program, bool outOfBoundsState) : totalBitOffset(0) {
if (outOfBoundsState) {
outOfBoundsBit = 0;
@ -64,6 +68,9 @@ namespace storm {
totalBitOffset += bitwidth;
}
}
for (auto const& oblab : program.getObservationLabels()) {
observationLabels.emplace_back(oblab.getName());
}
sortVariables();
}

9
src/storm/generator/VariableInformation.h

@ -90,6 +90,12 @@ namespace storm {
bool observable;
};
struct ObservationLabelInformation {
ObservationLabelInformation(std::string const& name);
std::string name;
bool deterministic = true;
};
// A structure storing information about the used variables of the program.
struct VariableInformation {
VariableInformation(storm::prism::Program const& program, bool outOfBoundsState = false);
@ -114,6 +120,9 @@ namespace storm {
/// The integer variables.
std::vector<IntegerVariableInformation> integerVariables;
/// The observation labels
std::vector<ObservationLabelInformation> observationLabels;
/// Replacements for each array variable
std::unordered_map<storm::expressions::Variable, std::vector<uint64_t>> arrayVariableToElementInformations;

19
src/storm/models/sparse/Pomdp.cpp

@ -15,12 +15,12 @@ namespace storm {
}
template <typename ValueType, typename RewardModelType>
Pomdp<ValueType, RewardModelType>::Pomdp(storm::storage::sparse::ModelComponents<ValueType, RewardModelType> const &components, bool canonicFlag) : Mdp<ValueType, RewardModelType>(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) {
Pomdp<ValueType, RewardModelType>::Pomdp(storm::storage::sparse::ModelComponents<ValueType, RewardModelType> const &components, bool canonicFlag) : Mdp<ValueType, RewardModelType>(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) , observationValuations(components.observationValuations) {
computeNrObservations();
}
template <typename ValueType, typename RewardModelType>
Pomdp<ValueType, RewardModelType>::Pomdp(storm::storage::sparse::ModelComponents<ValueType, RewardModelType> &&components, bool canonicFlag): Mdp<ValueType, RewardModelType>(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) {
Pomdp<ValueType, RewardModelType>::Pomdp(storm::storage::sparse::ModelComponents<ValueType, RewardModelType> &&components, bool canonicFlag): Mdp<ValueType, RewardModelType>(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) , observationValuations(components.observationValuations) {
computeNrObservations();
}
@ -100,6 +100,21 @@ namespace storm {
return result;
}
template<typename ValueType, typename RewardModelType>
bool Pomdp<ValueType, RewardModelType>::hasObservationValuations() const {
return static_cast<bool>(observationValuations);
}
template<typename ValueType, typename RewardModelType>
storm::storage::sparse::StateValuations const& Pomdp<ValueType, RewardModelType>::getObservationValuations() const {
return observationValuations.get();
}
template<typename ValueType, typename RewardModelType>
boost::optional<storm::storage::sparse::StateValuations> const& Pomdp<ValueType, RewardModelType>::getOptionalObservationValuations() const {
return observationValuations;
}
template<typename ValueType, typename RewardModelType>
bool Pomdp<ValueType, RewardModelType>::isCanonic() const {
return canonicFlag;

12
src/storm/models/sparse/Pomdp.h

@ -78,6 +78,12 @@ namespace storm {
std::vector<uint64_t> getStatesWithObservation(uint32_t observation) const;
bool hasObservationValuations() const;
storm::storage::sparse::StateValuations const& getObservationValuations() const;
boost::optional<storm::storage::sparse::StateValuations> const& getOptionalObservationValuations() const;
bool isCanonic() const;
void setIsCanonic(bool newValue = true);
@ -94,11 +100,13 @@ namespace storm {
// TODO: consider a bitvector based presentation (depending on our needs).
std::vector<uint32_t> observations;
uint64_t nrObservations;
bool canonicFlag = false;
boost::optional<storm::storage::sparse::StateValuations> observationValuations;
void computeNrObservations();
};
}

2
src/storm/storage/sparse/ModelComponents.h

@ -67,6 +67,8 @@ namespace storm {
// The POMDP observations
boost::optional<std::vector<uint32_t>> observabilityClasses;
boost::optional<storm::storage::sparse::StateValuations> observationValuations;
// Continuous time specific components (CTMCs, Markov Automata):
// True iff the transition values (for Markovian choices) are interpreted as rates.
bool rateTransitions;

116
src/storm/storage/sparse/StateValuations.cpp

@ -10,7 +10,7 @@ namespace storm {
namespace storage {
namespace sparse {
StateValuations::StateValuation::StateValuation(std::vector<bool>&& booleanValues, std::vector<int64_t>&& integerValues, std::vector<storm::RationalNumber>&& rationalValues) : booleanValues(std::move(booleanValues)), integerValues(std::move(integerValues)), rationalValues(std::move(rationalValues)) {
StateValuations::StateValuation::StateValuation(std::vector<bool>&& booleanValues, std::vector<int64_t>&& integerValues, std::vector<storm::RationalNumber>&& rationalValues, std::vector<int64_t>&& observationLabelValues) : booleanValues(std::move(booleanValues)), integerValues(std::move(integerValues)), rationalValues(std::move(rationalValues)), observationLabelValues(std::move(observationLabelValues)) {
// Intentionally left empty
}
@ -20,14 +20,45 @@ namespace storm {
return valuations[stateIndex];
}
StateValuations::StateValueIterator::StateValueIterator(typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableIt, StateValuation const* valuation) : variableIt(variableIt), valuation(valuation) {
StateValuations::StateValueIterator::StateValueIterator(typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableIt,
typename std::map<std::string, uint64_t>::const_iterator labelIt,
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableBegin ,
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableEnd,
typename std::map<std::string, uint64_t>::const_iterator labelBegin,
typename std::map<std::string, uint64_t>::const_iterator labelEnd,
StateValuation const* valuation) : variableIt(variableIt), labelIt(labelIt),
variableBegin(variableBegin), variableEnd(variableEnd),
labelBegin(labelBegin), labelEnd(labelEnd), valuation(valuation) {
// Intentionally left empty.
}
storm::expressions::Variable const& StateValuations::StateValueIterator::getVariable() const { return variableIt->first; }
bool StateValuations::StateValueIterator::isBoolean() const { return getVariable().hasBooleanType(); }
bool StateValuations::StateValueIterator::isInteger() const { return getVariable().hasIntegerType(); }
bool StateValuations::StateValueIterator::isRational() const { return getVariable().hasRationalType(); }
bool StateValuations::StateValueIterator::isVariableAssignment() const {
return variableIt != variableEnd;
}
bool StateValuations::StateValueIterator::isLabelAssignment() const {
return variableIt == variableEnd;
}
storm::expressions::Variable const& StateValuations::StateValueIterator::getVariable() const {
STORM_LOG_ASSERT(isVariableAssignment(), "Does not point to a variable");
return variableIt->first;
}
std::string const& StateValuations::StateValueIterator::getLabel() const {
STORM_LOG_ASSERT(isLabelAssignment(), "Does not point to a label");
return labelIt->first;
}
bool StateValuations::StateValueIterator::isBoolean() const { return isVariableAssignment() && getVariable().hasBooleanType(); }
bool StateValuations::StateValueIterator::isInteger() const { return isVariableAssignment() && getVariable().hasIntegerType(); }
bool StateValuations::StateValueIterator::isRational() const { return isVariableAssignment() && getVariable().hasRationalType(); }
std::string const& StateValuations::StateValueIterator::getName() const {
if(isVariableAssignment()) {
return getVariable().getName();
} else {
return getLabel();
}
}
bool StateValuations::StateValueIterator::getBooleanValue() const {
STORM_LOG_ASSERT(isBoolean(), "Variable has no boolean type.");
@ -39,6 +70,12 @@ namespace storm {
return valuation->integerValues[variableIt->second];
}
int64_t StateValuations::StateValueIterator::getLabelValue() const {
STORM_LOG_ASSERT(isLabelAssignment(), "Not a label assignment");
STORM_LOG_ASSERT(labelIt->second < valuation->observationLabelValues.size(), "Label index " << labelIt->second << " larger than number of labels " << valuation->observationLabelValues.size());
return valuation->observationLabelValues[labelIt->second];
}
storm::RationalNumber StateValuations::StateValueIterator::getRationalValue() const {
STORM_LOG_ASSERT(isRational(), "Variable has no rational type.");
return valuation->rationalValues[variableIt->second];
@ -46,33 +83,41 @@ namespace storm {
bool StateValuations::StateValueIterator::operator==(StateValueIterator const& other) {
STORM_LOG_ASSERT(valuation == valuation, "Comparing iterators for different states");
return variableIt == other.variableIt;
return variableIt == other.variableIt && labelIt == other.labelIt;
}
bool StateValuations::StateValueIterator::operator!=(StateValueIterator const& other) {
STORM_LOG_ASSERT(valuation == valuation, "Comparing iterators for different states");
return variableIt != other.variableIt;
return !(*this == other);
}
typename StateValuations::StateValueIterator& StateValuations::StateValueIterator::operator++() {
++variableIt;
if(variableIt != variableEnd ) {
++variableIt;
} else {
++labelIt;
}
return *this;
}
typename StateValuations::StateValueIterator& StateValuations::StateValueIterator::operator--() {
--variableIt;
if (labelIt != labelBegin) {
--labelIt;
} else {
--variableIt;
}
return *this;
}
StateValuations::StateValueIteratorRange::StateValueIteratorRange(std::map<storm::expressions::Variable, uint64_t> const& variableMap, StateValuation const* valuation) : variableMap(variableMap), valuation(valuation) {
StateValuations::StateValueIteratorRange::StateValueIteratorRange(std::map<storm::expressions::Variable, uint64_t> const& variableMap, std::map<std::string, uint64_t> const& labelMap, StateValuation const* valuation) : variableMap(variableMap), labelMap(labelMap), valuation(valuation) {
// Intentionally left empty.
}
StateValuations::StateValueIterator StateValuations::StateValueIteratorRange::begin() const {
return StateValueIterator(variableMap.cbegin(), valuation);
return StateValueIterator(variableMap.cbegin(), labelMap.cbegin(), variableMap.cbegin(), variableMap.cend(), labelMap.cbegin(), labelMap.cend(), valuation);
}
StateValuations::StateValueIterator StateValuations::StateValueIteratorRange::end() const {
return StateValueIterator(variableMap.cend(), valuation);
return StateValueIterator(variableMap.cend(), labelMap.cend(), variableMap.cbegin(), variableMap.cend(), labelMap.cbegin(), labelMap.cend(), valuation);
}
bool StateValuations::getBooleanValue(storm::storage::sparse::state_type const& stateIndex, storm::expressions::Variable const& booleanVariable) const {
@ -94,8 +139,8 @@ namespace storm {
}
bool StateValuations::isEmpty(storm::storage::sparse::state_type const& stateIndex) const {
auto const& valuation = getValuation(stateIndex);
return valuation.booleanValues.empty() && valuation.integerValues.empty() && valuation.rationalValues.empty();
auto const& valuation = valuations[stateIndex]; // Do not use getValuations, as that is only valid after adding stuff.
return valuation.booleanValues.empty() && valuation.integerValues.empty() && valuation.rationalValues.empty() && valuation.observationLabelValues.empty();
}
std::string StateValuations::toString(storm::storage::sparse::state_type const& stateIndex, bool pretty, boost::optional<std::set<storm::expressions::Variable>> const& selectedVariables) const {
@ -115,11 +160,13 @@ namespace storm {
if (valIt.isBoolean() && !valIt.getBooleanValue()) {
stream << "!";
}
stream << valIt.getVariable().getName();
stream << valIt.getName();
if (valIt.isInteger()) {
stream << "=" << valIt.getIntegerValue();
} else if (valIt.isRational()) {
stream << "=" << valIt.getRationalValue();
} else if (valIt.isLabelAssignment()) {
stream << "=" << valIt.getLabelValue();
} else {
STORM_LOG_THROW(valIt.isBoolean(), storm::exceptions::InvalidTypeException, "Unexpected variable type.");
}
@ -130,6 +177,8 @@ namespace storm {
stream << valIt.getIntegerValue();
} else if (valIt.isRational()) {
stream << valIt.getRationalValue();
} else if (valIt.isLabelAssignment()) {
stream << valIt.getLabelValue();
}
}
assignments.push_back(stream.str());
@ -161,12 +210,12 @@ namespace storm {
}
if (valIt.isBoolean()) {
result[valIt.getVariable().getName()] = valIt.getBooleanValue();
result[valIt.getName()] = valIt.getBooleanValue();
} else if (valIt.isInteger()) {
result[valIt.getVariable().getName()] = valIt.getIntegerValue();
result[valIt.getName()] = valIt.getIntegerValue();
} else {
STORM_LOG_ASSERT(valIt.isRational(), "Unexpected variable type.");
result[valIt.getVariable().getName()] = valIt.getRationalValue();
result[valIt.getName()] = valIt.getRationalValue();
}
if (selectedVariables) {
@ -220,7 +269,7 @@ namespace storm {
typename StateValuations::StateValueIteratorRange StateValuations::at(state_type const& state) const {
STORM_LOG_ASSERT(state < getNumberOfStates(), "Invalid state index.");
return StateValueIteratorRange({variableToIndexMap, &(valuations[state])});
return StateValueIteratorRange({variableToIndexMap, observationLabels, &(valuations[state])});
}
uint_fast64_t StateValuations::getNumberOfStates() const {
@ -248,7 +297,7 @@ namespace storm {
return StateValuations(variableToIndexMap, std::move(selectedValuations));
}
StateValuationsBuilder::StateValuationsBuilder() : booleanVarCount(0), integerVarCount(0), rationalVarCount(0) {
StateValuationsBuilder::StateValuationsBuilder() : booleanVarCount(0), integerVarCount(0), rationalVarCount(0), labelCount(0) {
// Intentionally left empty.
}
@ -266,23 +315,40 @@ namespace storm {
}
}
void StateValuationsBuilder::addState(storm::storage::sparse::state_type const& state, std::vector<bool>&& booleanValues, std::vector<int64_t>&& integerValues, std::vector<storm::RationalNumber>&& rationalValues) {
void StateValuationsBuilder::addObservationLabel(const std::string &label) {
currentStateValuations.observationLabels[label] = labelCount++;
}
void StateValuationsBuilder::addState(storm::storage::sparse::state_type const& state, std::vector<bool>&& booleanValues, std::vector<int64_t>&& integerValues, std::vector<storm::RationalNumber>&& rationalValues,std::vector<int64_t>&& observationLabelValues) {
if (state > currentStateValuations.valuations.size()) {
currentStateValuations.valuations.resize(state);
}
if (state == currentStateValuations.valuations.size()) {
currentStateValuations.valuations.emplace_back(std::move(booleanValues), std::move(integerValues), std::move(rationalValues));
currentStateValuations.valuations.emplace_back(std::move(booleanValues), std::move(integerValues), std::move(rationalValues), std::move(observationLabelValues));
} else {
STORM_LOG_ASSERT(currentStateValuations.isEmpty(state), "Adding a valuation to the same state multiple times.");
currentStateValuations.valuations[state] = typename StateValuations::StateValuation(std::move(booleanValues), std::move(integerValues), std::move(rationalValues));
currentStateValuations.valuations[state] = typename StateValuations::StateValuation(std::move(booleanValues), std::move(integerValues), std::move(rationalValues), std::move(observationLabelValues));
}
}
uint64_t StateValuationsBuilder::getBooleanVarCount() const {
return booleanVarCount;
}
uint64_t StateValuationsBuilder::getIntegerVarCount() const {
return integerVarCount;
}
uint64_t StateValuationsBuilder::getLabelCount() const {
return labelCount;
}
StateValuations StateValuationsBuilder::build(std::size_t totalStateCount) {
return std::move(currentStateValuations);
booleanVarCount = 0;
integerVarCount = 0;
rationalVarCount = 0;
labelCount = 0;
}
}
}

38
src/storm/storage/sparse/StateValuations.h

@ -26,45 +26,65 @@ namespace storm {
public:
friend class StateValuations;
StateValuation() = default;
StateValuation(std::vector<bool>&& booleanValues, std::vector<int64_t>&& integerValues, std::vector<storm::RationalNumber>&& rationalValues);
StateValuation(std::vector<bool>&& booleanValues, std::vector<int64_t>&& integerValues, std::vector<storm::RationalNumber>&& rationalValues, std::vector<int64_t>&& observationLabelValues = {});
private:
std::vector<bool> booleanValues;
std::vector<int64_t> integerValues;
std::vector<storm::RationalNumber> rationalValues;
std::vector<int64_t> observationLabelValues;
};
class StateValueIterator {
public:
StateValueIterator(typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableIt, StateValuation const* valuation);
StateValueIterator(typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableIt,
typename std::map<std::string, uint64_t>::const_iterator labelIt,
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableBegin ,
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableEnd,
typename std::map<std::string, uint64_t>::const_iterator labelBegin,
typename std::map<std::string, uint64_t>::const_iterator labelEnd,
StateValuation const* valuation);
bool operator==(StateValueIterator const& other);
bool operator!=(StateValueIterator const& other);
StateValueIterator& operator++();
StateValueIterator& operator--();
bool isVariableAssignment() const;
bool isLabelAssignment() const;
storm::expressions::Variable const& getVariable() const;
std::string const& getLabel() const;
bool isBoolean() const;
bool isInteger() const;
bool isRational() const;
std::string const& getName() const;
// These shall only be called if the variable has the correct type
bool getBooleanValue() const;
int64_t getIntegerValue() const;
storm::RationalNumber getRationalValue() const;
int64_t getLabelValue() const;
private:
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableIt;
typename std::map<std::string, uint64_t>::const_iterator labelIt;
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableBegin;
typename std::map<storm::expressions::Variable, uint64_t>::const_iterator variableEnd;
typename std::map<std::string, uint64_t>::const_iterator labelBegin;
typename std::map<std::string, uint64_t>::const_iterator labelEnd;
StateValuation const* const valuation;
};
class StateValueIteratorRange {
public:
StateValueIteratorRange(std::map<storm::expressions::Variable, uint64_t> const& variableMap, StateValuation const* valuation);
StateValueIteratorRange(std::map<storm::expressions::Variable, uint64_t> const& variableMap, std::map<std::string, uint64_t> const& labelMap, StateValuation const* valuation);
StateValueIterator begin() const;
StateValueIterator end() const;
private:
std::map<storm::expressions::Variable, uint64_t> const& variableMap;
std::map<std::string, uint64_t> const& labelMap;
StateValuation const* const valuation;
};
@ -117,6 +137,7 @@ namespace storm {
StateValuation const& getValuation(storm::storage::sparse::state_type const& stateIndex) const;
std::map<storm::expressions::Variable, uint64_t> variableToIndexMap;
std::map<std::string, uint64_t> observationLabels;
// A mapping from state indices to their variable valuations.
std::vector<StateValuation> valuations;
@ -127,28 +148,35 @@ namespace storm {
StateValuationsBuilder();
/*! Adds a new variable to keep track of for the state valuations.
*! All variables need to be added before adding new states.
* All variables need to be added before adding new states.
*/
void addVariable(storm::expressions::Variable const& variable);
void addObservationLabel(std::string const& label);
/*!
* Adds a new state.
* The variable values have to be given in the same order as the variables have been added.
* The number of given variable values for each type needs to match the number of added variables.
* After calling this method, no more variables should be added.
*/
void addState(storm::storage::sparse::state_type const& state, std::vector<bool>&& booleanValues = {}, std::vector<int64_t>&& integerValues = {}, std::vector<storm::RationalNumber>&& rationalValues = {});
void addState(storm::storage::sparse::state_type const& state, std::vector<bool>&& booleanValues = {}, std::vector<int64_t>&& integerValues = {}, std::vector<storm::RationalNumber>&& rationalValues = {}, std::vector<int64_t>&& observationLabelValues = {});
/*!
* Creates the finalized state valuations object.
*/
StateValuations build(std::size_t totalStateCount);
uint64_t getBooleanVarCount() const;
uint64_t getIntegerVarCount() const;
uint64_t getLabelCount() const;
private:
StateValuations currentStateValuations;
uint64_t booleanVarCount;
uint64_t integerVarCount;
uint64_t rationalVarCount;
uint64_t labelCount;
};
}
}

Loading…
Cancel
Save