diff --git a/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp b/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp index 28f0d8c30..601ce6699 100644 --- a/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp +++ b/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp @@ -151,6 +151,7 @@ namespace storm { modelcomponents.stateValuations = pomdp.getOptionalStateValuations(); modelcomponents.choiceLabeling = pomdp.getChoiceLabeling(); modelcomponents.choiceLabeling->permuteItems(permutation); + modelcomponents.observationValuations = pomdp.getOptionalObservationValuations(); return std::make_shared>(modelcomponents, true); } diff --git a/src/storm/builder/BuilderOptions.cpp b/src/storm/builder/BuilderOptions.cpp index 9831040c3..17de90355 100644 --- a/src/storm/builder/BuilderOptions.cpp +++ b/src/storm/builder/BuilderOptions.cpp @@ -130,10 +130,14 @@ namespace storm { bool BuilderOptions::isBuildChoiceLabelsSet() const { return buildChoiceLabels; } - - bool BuilderOptions::isBuildStateValuationsSet() const { + + bool BuilderOptions::isBuildStateValuationsSet() const { return buildStateValuations; } + + bool BuilderOptions::isBuildObservationValuationsSet() const { + return buildObservationValuations; + } bool BuilderOptions::isBuildChoiceOriginsSet() const { return buildChoiceOrigins; @@ -237,6 +241,11 @@ namespace storm { buildStateValuations = newValue; return *this; } + + BuilderOptions& BuilderOptions::setBuildObservationValuations(bool newValue) { + buildObservationValuations = newValue; + return *this; + } BuilderOptions& BuilderOptions::setBuildChoiceOrigins(bool newValue) { buildChoiceOrigins = newValue; diff --git a/src/storm/builder/BuilderOptions.h b/src/storm/builder/BuilderOptions.h index a06817816..af424497a 100644 --- a/src/storm/builder/BuilderOptions.h +++ b/src/storm/builder/BuilderOptions.h @@ -106,6 +106,7 @@ namespace storm { bool isApplyMaximalProgressAssumptionSet() const; bool isBuildChoiceLabelsSet() const; bool isBuildStateValuationsSet() const; + bool isBuildObservationValuationsSet() const; bool isBuildChoiceOriginsSet() const; bool isBuildAllRewardModelsSet() const; bool isBuildAllLabelsSet() const; @@ -159,6 +160,14 @@ namespace storm { * @return this */ BuilderOptions& setBuildStateValuations(bool newValue = true); + + /** + * Should a observation valuation mapping be built? + * @param newValue The new value (default true) + * @return this + */ + BuilderOptions& setBuildObservationValuations(bool newValue = true); + /** * Should the origins the different choices be built? * @param newValue The new value (default true) @@ -236,6 +245,9 @@ namespace storm { /// A flag indicating whether or not to build for each state the variable valuation from which it originates. bool buildStateValuations; + + /// A flag indicating whether or not to build observation valuations + bool buildObservationValuations; // A flag that indicates whether or not to generate the information from which parts of the model specification // each choice originates. diff --git a/src/storm/builder/ExplicitModelBuilder.cpp b/src/storm/builder/ExplicitModelBuilder.cpp index 9f5c7d298..8fedf075b 100644 --- a/src/storm/builder/ExplicitModelBuilder.cpp +++ b/src/storm/builder/ExplicitModelBuilder.cpp @@ -358,69 +358,17 @@ namespace storm { } if (generator->isPartiallyObservable()) { std::vector classes; - uint32_t newObservation = 0; classes.resize(stateStorage.getNumberOfStates()); std::unordered_map, uint32_t>>> observationActions; for (auto const& bitVectorIndexPair : stateStorage.stateToId) { uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first); - uint32_t observation = -1; // Is replaced later on. - bool checkActionNames = false; - if (checkActionNames) { - bool foundActionSet = false; - std::vector actionNames; - bool addedAnonymousAction = false; - for (uint64_t choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; - choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second + - 1]; ++choice) { - if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) { - STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, - "Cannot have multiple anonymous actions, as these cannot be mapped correctly."); - actionNames.push_back(""); - addedAnonymousAction = true; - } else { - STORM_LOG_ASSERT( - modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, - "Expect choice labelling to contain exactly one label at this point, but found " - << modelComponents.choiceLabeling.get().getLabelsOfChoice( - choice).size()); - actionNames.push_back( - *modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin()); - } - } - STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " - << storm::utility::vector::toString(actionNames)); - auto it = observationActions.find(varObservation); - if (it == observationActions.end()) { - observationActions.emplace(varObservation, - std::vector, uint32_t>>()); - } else { - for (auto const &entries : it->second) { - STORM_LOG_TRACE(storm::utility::vector::toString(entries.first)); - if (entries.first == actionNames) { - observation = entries.second; - foundActionSet = true; - break; - } - } - - STORM_LOG_THROW( - generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet, - storm::exceptions::WrongFormatException, - "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option."); - - } - if (!foundActionSet) { - observation = newObservation; - observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation); - ++newObservation; - } - - classes[bitVectorIndexPair.second] = observation; - } else { - classes[bitVectorIndexPair.second] = varObservation; - } + classes[bitVectorIndexPair.second] = varObservation; } + modelComponents.observabilityClasses = classes; + if(generator->getOptions().isBuildObservationValuationsSet()) { + modelComponents.observationValuations = generator->makeObservationValuation(); + } } return modelComponents; } diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp index eb681382b..e62825d17 100644 --- a/src/storm/generator/NextStateGenerator.cpp +++ b/src/storm/generator/NextStateGenerator.cpp @@ -61,6 +61,25 @@ namespace storm { } return result; } + + template + storm::storage::sparse::StateValuationsBuilder NextStateGenerator::initializeObservationValuationsBuilder() const { + storm::storage::sparse::StateValuationsBuilder result; + for (auto const& v : variableInformation.booleanVariables) { + if(v.observable) { + result.addVariable(v.variable); + } + } + for (auto const& v : variableInformation.integerVariables) { + if(v.observable) { + result.addVariable(v.variable); + } + } + for (auto const& l : variableInformation.observationLabels) { + result.addObservationLabel(l.name); + } + return result; + } template void NextStateGenerator::load(CompressedState const& state) { @@ -93,7 +112,42 @@ namespace storm { extractVariableValues(*this->state, variableInformation, integerValues, booleanValues, integerValues); valuationsBuilder.addState(currentStateIndex, std::move(booleanValues), std::move(integerValues)); } - + + template + storm::storage::sparse::StateValuations NextStateGenerator::makeObservationValuation() const { + storm::storage::sparse::StateValuationsBuilder valuationsBuilder = initializeObservationValuationsBuilder(); + for (auto const& observationEntry : observabilityMap) { + std::vector booleanValues; + booleanValues.reserve( + variableInformation.booleanVariables.size()); // TODO: use number of observable boolean variables + std::vector integerValues; + integerValues.reserve(variableInformation.locationVariables.size() + + variableInformation.integerVariables.size()); // TODO: use number of observable integer variables + std::vector observationLabelValues; + observationLabelValues.reserve(variableInformation.observationLabels.size()); + expressions::SimpleValuation val = unpackStateIntoValuation(observationEntry.first, variableInformation, *expressionManager); + for (auto const& v : variableInformation.booleanVariables) { + if (v.observable) { + booleanValues.push_back(val.getBooleanValue(v.variable)); + } + } + for (auto const& v : variableInformation.integerVariables) { + if (v.observable) { + integerValues.push_back(val.getIntegerValue(v.variable)); + } + } + for(uint64_t labelStart = variableInformation.getTotalBitOffset(true); labelStart < observationEntry.first.size(); labelStart += 64) { + observationLabelValues.push_back(observationEntry.first.getAsInt(labelStart, 64)); + } + valuationsBuilder.addState(observationEntry.second, std::move(booleanValues), std::move(integerValues), {}, std::move(observationLabelValues)); + } + return valuationsBuilder.build(observabilityMap.size()); + + } + + + + template storm::models::sparse::StateLabeling NextStateGenerator::label(storm::storage::sparse::StateStorage const& stateStorage, std::vector const& initialStateIndices, std::vector const& deadlockStateIndices, std::vector> labelsAndExpressions) { @@ -213,7 +267,8 @@ namespace storm { if (this->mask.size() == 0) { this->mask = computeObservabilityMask(variableInformation); } - return unpackStateToObservabilityClass(state, evaluateObservationLabels(state), observabilityMap, mask); + uint32_t classId = unpackStateToObservabilityClass(state, evaluateObservationLabels(state), observabilityMap, mask); + return classId; } template diff --git a/src/storm/generator/NextStateGenerator.h b/src/storm/generator/NextStateGenerator.h index b782ad908..59f50843e 100644 --- a/src/storm/generator/NextStateGenerator.h +++ b/src/storm/generator/NextStateGenerator.h @@ -64,7 +64,9 @@ namespace storm { /// Adds the valuation for the currently loaded state to the given builder virtual void addStateValuation(storm::storage::sparse::state_type const& currentStateIndex, storm::storage::sparse::StateValuationsBuilder& valuationsBuilder) const; - + /// Adds the valuation for the currently loaded state + virtual storm::storage::sparse::StateValuations makeObservationValuation() const; + virtual std::size_t getNumberOfRewardModels() const = 0; virtual storm::builder::RewardModelInformation getRewardModelInformation(uint64_t const& index) const = 0; @@ -95,6 +97,8 @@ namespace storm { virtual storm::storage::BitVector evaluateObservationLabels(CompressedState const& state) const =0; + virtual storm::storage::sparse::StateValuationsBuilder initializeObservationValuationsBuilder() const; + void postprocess(StateBehavior& result); /// The options to be used for next-state generation. diff --git a/src/storm/generator/VariableInformation.cpp b/src/storm/generator/VariableInformation.cpp index 72a351250..91dcf30ee 100644 --- a/src/storm/generator/VariableInformation.cpp +++ b/src/storm/generator/VariableInformation.cpp @@ -29,6 +29,10 @@ namespace storm { LocationVariableInformation::LocationVariableInformation(storm::expressions::Variable const& variable, uint64_t highestValue, uint_fast64_t bitOffset, uint_fast64_t bitWidth, bool observable) : variable(variable), highestValue(highestValue), bitOffset(bitOffset), bitWidth(bitWidth), observable(observable) { // Intentionally left empty. } + + ObservationLabelInformation::ObservationLabelInformation(const std::string &name) : name(name) { + // Intentionally left empty. + } VariableInformation::VariableInformation(storm::prism::Program const& program, bool outOfBoundsState) : totalBitOffset(0) { if (outOfBoundsState) { @@ -64,6 +68,9 @@ namespace storm { totalBitOffset += bitwidth; } } + for (auto const& oblab : program.getObservationLabels()) { + observationLabels.emplace_back(oblab.getName()); + } sortVariables(); } diff --git a/src/storm/generator/VariableInformation.h b/src/storm/generator/VariableInformation.h index b7be52200..be97290f8 100644 --- a/src/storm/generator/VariableInformation.h +++ b/src/storm/generator/VariableInformation.h @@ -89,6 +89,12 @@ namespace storm { bool observable; }; + + struct ObservationLabelInformation { + ObservationLabelInformation(std::string const& name); + std::string name; + bool deterministic = true; + }; // A structure storing information about the used variables of the program. struct VariableInformation { @@ -113,7 +119,10 @@ namespace storm { /// The integer variables. std::vector integerVariables; - + + /// The observation labels + std::vector observationLabels; + /// Replacements for each array variable std::unordered_map> arrayVariableToElementInformations; diff --git a/src/storm/models/sparse/Pomdp.cpp b/src/storm/models/sparse/Pomdp.cpp index dc5461120..659c13e2e 100644 --- a/src/storm/models/sparse/Pomdp.cpp +++ b/src/storm/models/sparse/Pomdp.cpp @@ -15,12 +15,12 @@ namespace storm { } template - Pomdp::Pomdp(storm::storage::sparse::ModelComponents const &components, bool canonicFlag) : Mdp(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) { + Pomdp::Pomdp(storm::storage::sparse::ModelComponents const &components, bool canonicFlag) : Mdp(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) , observationValuations(components.observationValuations) { computeNrObservations(); } template - Pomdp::Pomdp(storm::storage::sparse::ModelComponents &&components, bool canonicFlag): Mdp(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) { + Pomdp::Pomdp(storm::storage::sparse::ModelComponents &&components, bool canonicFlag): Mdp(components, storm::models::ModelType::Pomdp), observations(components.observabilityClasses.get()), canonicFlag(canonicFlag) , observationValuations(components.observationValuations) { computeNrObservations(); } @@ -100,6 +100,21 @@ namespace storm { return result; } + template + bool Pomdp::hasObservationValuations() const { + return static_cast(observationValuations); + } + + template + storm::storage::sparse::StateValuations const& Pomdp::getObservationValuations() const { + return observationValuations.get(); + } + + template + boost::optional const& Pomdp::getOptionalObservationValuations() const { + return observationValuations; + } + template bool Pomdp::isCanonic() const { return canonicFlag; diff --git a/src/storm/models/sparse/Pomdp.h b/src/storm/models/sparse/Pomdp.h index cacee0953..6b1165482 100644 --- a/src/storm/models/sparse/Pomdp.h +++ b/src/storm/models/sparse/Pomdp.h @@ -78,6 +78,12 @@ namespace storm { std::vector getStatesWithObservation(uint32_t observation) const; + bool hasObservationValuations() const; + + storm::storage::sparse::StateValuations const& getObservationValuations() const; + + boost::optional const& getOptionalObservationValuations() const; + bool isCanonic() const; void setIsCanonic(bool newValue = true); @@ -94,11 +100,13 @@ namespace storm { // TODO: consider a bitvector based presentation (depending on our needs). std::vector observations; - uint64_t nrObservations; - bool canonicFlag = false; + boost::optional observationValuations; + + + void computeNrObservations(); }; } diff --git a/src/storm/storage/sparse/ModelComponents.h b/src/storm/storage/sparse/ModelComponents.h index f71037f1a..ca325b73c 100644 --- a/src/storm/storage/sparse/ModelComponents.h +++ b/src/storm/storage/sparse/ModelComponents.h @@ -67,6 +67,8 @@ namespace storm { // The POMDP observations boost::optional> observabilityClasses; + boost::optional observationValuations; + // Continuous time specific components (CTMCs, Markov Automata): // True iff the transition values (for Markovian choices) are interpreted as rates. bool rateTransitions; diff --git a/src/storm/storage/sparse/StateValuations.cpp b/src/storm/storage/sparse/StateValuations.cpp index 541e56b90..589a61ccd 100644 --- a/src/storm/storage/sparse/StateValuations.cpp +++ b/src/storm/storage/sparse/StateValuations.cpp @@ -10,7 +10,7 @@ namespace storm { namespace storage { namespace sparse { - StateValuations::StateValuation::StateValuation(std::vector&& booleanValues, std::vector&& integerValues, std::vector&& rationalValues) : booleanValues(std::move(booleanValues)), integerValues(std::move(integerValues)), rationalValues(std::move(rationalValues)) { + StateValuations::StateValuation::StateValuation(std::vector&& booleanValues, std::vector&& integerValues, std::vector&& rationalValues, std::vector&& observationLabelValues) : booleanValues(std::move(booleanValues)), integerValues(std::move(integerValues)), rationalValues(std::move(rationalValues)), observationLabelValues(std::move(observationLabelValues)) { // Intentionally left empty } @@ -20,15 +20,46 @@ namespace storm { return valuations[stateIndex]; } - StateValuations::StateValueIterator::StateValueIterator(typename std::map::const_iterator variableIt, StateValuation const* valuation) : variableIt(variableIt), valuation(valuation) { + StateValuations::StateValueIterator::StateValueIterator(typename std::map::const_iterator variableIt, + typename std::map::const_iterator labelIt, + typename std::map::const_iterator variableBegin , + typename std::map::const_iterator variableEnd, + typename std::map::const_iterator labelBegin, + typename std::map::const_iterator labelEnd, + StateValuation const* valuation) : variableIt(variableIt), labelIt(labelIt), + variableBegin(variableBegin), variableEnd(variableEnd), + labelBegin(labelBegin), labelEnd(labelEnd), valuation(valuation) { // Intentionally left empty. } - storm::expressions::Variable const& StateValuations::StateValueIterator::getVariable() const { return variableIt->first; } - bool StateValuations::StateValueIterator::isBoolean() const { return getVariable().hasBooleanType(); } - bool StateValuations::StateValueIterator::isInteger() const { return getVariable().hasIntegerType(); } - bool StateValuations::StateValueIterator::isRational() const { return getVariable().hasRationalType(); } - + bool StateValuations::StateValueIterator::isVariableAssignment() const { + return variableIt != variableEnd; + } + + bool StateValuations::StateValueIterator::isLabelAssignment() const { + return variableIt == variableEnd; + } + + storm::expressions::Variable const& StateValuations::StateValueIterator::getVariable() const { + STORM_LOG_ASSERT(isVariableAssignment(), "Does not point to a variable"); + return variableIt->first; + } + std::string const& StateValuations::StateValueIterator::getLabel() const { + STORM_LOG_ASSERT(isLabelAssignment(), "Does not point to a label"); + return labelIt->first; + } + bool StateValuations::StateValueIterator::isBoolean() const { return isVariableAssignment() && getVariable().hasBooleanType(); } + bool StateValuations::StateValueIterator::isInteger() const { return isVariableAssignment() && getVariable().hasIntegerType(); } + bool StateValuations::StateValueIterator::isRational() const { return isVariableAssignment() && getVariable().hasRationalType(); } + + std::string const& StateValuations::StateValueIterator::getName() const { + if(isVariableAssignment()) { + return getVariable().getName(); + } else { + return getLabel(); + } + } + bool StateValuations::StateValueIterator::getBooleanValue() const { STORM_LOG_ASSERT(isBoolean(), "Variable has no boolean type."); return valuation->booleanValues[variableIt->second]; @@ -38,7 +69,13 @@ namespace storm { STORM_LOG_ASSERT(isInteger(), "Variable has no integer type."); return valuation->integerValues[variableIt->second]; } - + + int64_t StateValuations::StateValueIterator::getLabelValue() const { + STORM_LOG_ASSERT(isLabelAssignment(), "Not a label assignment"); + STORM_LOG_ASSERT(labelIt->second < valuation->observationLabelValues.size(), "Label index " << labelIt->second << " larger than number of labels " << valuation->observationLabelValues.size()); + return valuation->observationLabelValues[labelIt->second]; + } + storm::RationalNumber StateValuations::StateValueIterator::getRationalValue() const { STORM_LOG_ASSERT(isRational(), "Variable has no rational type."); return valuation->rationalValues[variableIt->second]; @@ -46,33 +83,41 @@ namespace storm { bool StateValuations::StateValueIterator::operator==(StateValueIterator const& other) { STORM_LOG_ASSERT(valuation == valuation, "Comparing iterators for different states"); - return variableIt == other.variableIt; + return variableIt == other.variableIt && labelIt == other.labelIt; } bool StateValuations::StateValueIterator::operator!=(StateValueIterator const& other) { - STORM_LOG_ASSERT(valuation == valuation, "Comparing iterators for different states"); - return variableIt != other.variableIt; + return !(*this == other); } typename StateValuations::StateValueIterator& StateValuations::StateValueIterator::operator++() { - ++variableIt; + if(variableIt != variableEnd ) { + ++variableIt; + } else { + ++labelIt; + } + return *this; } typename StateValuations::StateValueIterator& StateValuations::StateValueIterator::operator--() { - --variableIt; + if (labelIt != labelBegin) { + --labelIt; + } else { + --variableIt; + } return *this; } - StateValuations::StateValueIteratorRange::StateValueIteratorRange(std::map const& variableMap, StateValuation const* valuation) : variableMap(variableMap), valuation(valuation) { + StateValuations::StateValueIteratorRange::StateValueIteratorRange(std::map const& variableMap, std::map const& labelMap, StateValuation const* valuation) : variableMap(variableMap), labelMap(labelMap), valuation(valuation) { // Intentionally left empty. } StateValuations::StateValueIterator StateValuations::StateValueIteratorRange::begin() const { - return StateValueIterator(variableMap.cbegin(), valuation); + return StateValueIterator(variableMap.cbegin(), labelMap.cbegin(), variableMap.cbegin(), variableMap.cend(), labelMap.cbegin(), labelMap.cend(), valuation); } StateValuations::StateValueIterator StateValuations::StateValueIteratorRange::end() const { - return StateValueIterator(variableMap.cend(), valuation); + return StateValueIterator(variableMap.cend(), labelMap.cend(), variableMap.cbegin(), variableMap.cend(), labelMap.cbegin(), labelMap.cend(), valuation); } bool StateValuations::getBooleanValue(storm::storage::sparse::state_type const& stateIndex, storm::expressions::Variable const& booleanVariable) const { @@ -94,8 +139,8 @@ namespace storm { } bool StateValuations::isEmpty(storm::storage::sparse::state_type const& stateIndex) const { - auto const& valuation = getValuation(stateIndex); - return valuation.booleanValues.empty() && valuation.integerValues.empty() && valuation.rationalValues.empty(); + auto const& valuation = valuations[stateIndex]; // Do not use getValuations, as that is only valid after adding stuff. + return valuation.booleanValues.empty() && valuation.integerValues.empty() && valuation.rationalValues.empty() && valuation.observationLabelValues.empty(); } std::string StateValuations::toString(storm::storage::sparse::state_type const& stateIndex, bool pretty, boost::optional> const& selectedVariables) const { @@ -115,11 +160,13 @@ namespace storm { if (valIt.isBoolean() && !valIt.getBooleanValue()) { stream << "!"; } - stream << valIt.getVariable().getName(); + stream << valIt.getName(); if (valIt.isInteger()) { stream << "=" << valIt.getIntegerValue(); } else if (valIt.isRational()) { stream << "=" << valIt.getRationalValue(); + } else if (valIt.isLabelAssignment()) { + stream << "=" << valIt.getLabelValue(); } else { STORM_LOG_THROW(valIt.isBoolean(), storm::exceptions::InvalidTypeException, "Unexpected variable type."); } @@ -130,6 +177,8 @@ namespace storm { stream << valIt.getIntegerValue(); } else if (valIt.isRational()) { stream << valIt.getRationalValue(); + } else if (valIt.isLabelAssignment()) { + stream << valIt.getLabelValue(); } } assignments.push_back(stream.str()); @@ -161,12 +210,12 @@ namespace storm { } if (valIt.isBoolean()) { - result[valIt.getVariable().getName()] = valIt.getBooleanValue(); + result[valIt.getName()] = valIt.getBooleanValue(); } else if (valIt.isInteger()) { - result[valIt.getVariable().getName()] = valIt.getIntegerValue(); + result[valIt.getName()] = valIt.getIntegerValue(); } else { STORM_LOG_ASSERT(valIt.isRational(), "Unexpected variable type."); - result[valIt.getVariable().getName()] = valIt.getRationalValue(); + result[valIt.getName()] = valIt.getRationalValue(); } if (selectedVariables) { @@ -220,7 +269,7 @@ namespace storm { typename StateValuations::StateValueIteratorRange StateValuations::at(state_type const& state) const { STORM_LOG_ASSERT(state < getNumberOfStates(), "Invalid state index."); - return StateValueIteratorRange({variableToIndexMap, &(valuations[state])}); + return StateValueIteratorRange({variableToIndexMap, observationLabels, &(valuations[state])}); } uint_fast64_t StateValuations::getNumberOfStates() const { @@ -248,7 +297,7 @@ namespace storm { return StateValuations(variableToIndexMap, std::move(selectedValuations)); } - StateValuationsBuilder::StateValuationsBuilder() : booleanVarCount(0), integerVarCount(0), rationalVarCount(0) { + StateValuationsBuilder::StateValuationsBuilder() : booleanVarCount(0), integerVarCount(0), rationalVarCount(0), labelCount(0) { // Intentionally left empty. } @@ -265,24 +314,41 @@ namespace storm { currentStateValuations.variableToIndexMap[variable] = rationalVarCount++; } } + + void StateValuationsBuilder::addObservationLabel(const std::string &label) { + currentStateValuations.observationLabels[label] = labelCount++; + } - void StateValuationsBuilder::addState(storm::storage::sparse::state_type const& state, std::vector&& booleanValues, std::vector&& integerValues, std::vector&& rationalValues) { + void StateValuationsBuilder::addState(storm::storage::sparse::state_type const& state, std::vector&& booleanValues, std::vector&& integerValues, std::vector&& rationalValues,std::vector&& observationLabelValues) { if (state > currentStateValuations.valuations.size()) { currentStateValuations.valuations.resize(state); } if (state == currentStateValuations.valuations.size()) { - currentStateValuations.valuations.emplace_back(std::move(booleanValues), std::move(integerValues), std::move(rationalValues)); + currentStateValuations.valuations.emplace_back(std::move(booleanValues), std::move(integerValues), std::move(rationalValues), std::move(observationLabelValues)); } else { STORM_LOG_ASSERT(currentStateValuations.isEmpty(state), "Adding a valuation to the same state multiple times."); - currentStateValuations.valuations[state] = typename StateValuations::StateValuation(std::move(booleanValues), std::move(integerValues), std::move(rationalValues)); + currentStateValuations.valuations[state] = typename StateValuations::StateValuation(std::move(booleanValues), std::move(integerValues), std::move(rationalValues), std::move(observationLabelValues)); } } + + uint64_t StateValuationsBuilder::getBooleanVarCount() const { + return booleanVarCount; + } + + uint64_t StateValuationsBuilder::getIntegerVarCount() const { + return integerVarCount; + } + + uint64_t StateValuationsBuilder::getLabelCount() const { + return labelCount; + } StateValuations StateValuationsBuilder::build(std::size_t totalStateCount) { return std::move(currentStateValuations); booleanVarCount = 0; integerVarCount = 0; rationalVarCount = 0; + labelCount = 0; } } } diff --git a/src/storm/storage/sparse/StateValuations.h b/src/storm/storage/sparse/StateValuations.h index 07027d52b..58925b4c4 100644 --- a/src/storm/storage/sparse/StateValuations.h +++ b/src/storm/storage/sparse/StateValuations.h @@ -26,45 +26,65 @@ namespace storm { public: friend class StateValuations; StateValuation() = default; - StateValuation(std::vector&& booleanValues, std::vector&& integerValues, std::vector&& rationalValues); + StateValuation(std::vector&& booleanValues, std::vector&& integerValues, std::vector&& rationalValues, std::vector&& observationLabelValues = {}); private: std::vector booleanValues; std::vector integerValues; std::vector rationalValues; + std::vector observationLabelValues; }; class StateValueIterator { public: - StateValueIterator(typename std::map::const_iterator variableIt, StateValuation const* valuation); + StateValueIterator(typename std::map::const_iterator variableIt, + typename std::map::const_iterator labelIt, + typename std::map::const_iterator variableBegin , + typename std::map::const_iterator variableEnd, + typename std::map::const_iterator labelBegin, + typename std::map::const_iterator labelEnd, + StateValuation const* valuation); bool operator==(StateValueIterator const& other); bool operator!=(StateValueIterator const& other); StateValueIterator& operator++(); StateValueIterator& operator--(); - + + bool isVariableAssignment() const; + bool isLabelAssignment() const; storm::expressions::Variable const& getVariable() const; + std::string const& getLabel() const; bool isBoolean() const; bool isInteger() const; bool isRational() const; - + + std::string const& getName() const; + // These shall only be called if the variable has the correct type bool getBooleanValue() const; int64_t getIntegerValue() const; storm::RationalNumber getRationalValue() const; + int64_t getLabelValue() const; private: typename std::map::const_iterator variableIt; + typename std::map::const_iterator labelIt; + typename std::map::const_iterator variableBegin; + typename std::map::const_iterator variableEnd; + typename std::map::const_iterator labelBegin; + typename std::map::const_iterator labelEnd; + StateValuation const* const valuation; }; class StateValueIteratorRange { public: - StateValueIteratorRange(std::map const& variableMap, StateValuation const* valuation); + StateValueIteratorRange(std::map const& variableMap, std::map const& labelMap, StateValuation const* valuation); StateValueIterator begin() const; StateValueIterator end() const; private: std::map const& variableMap; + std::map const& labelMap; StateValuation const* const valuation; }; @@ -117,6 +137,7 @@ namespace storm { StateValuation const& getValuation(storm::storage::sparse::state_type const& stateIndex) const; std::map variableToIndexMap; + std::map observationLabels; // A mapping from state indices to their variable valuations. std::vector valuations; @@ -127,28 +148,35 @@ namespace storm { StateValuationsBuilder(); /*! Adds a new variable to keep track of for the state valuations. - *! All variables need to be added before adding new states. + * All variables need to be added before adding new states. */ void addVariable(storm::expressions::Variable const& variable); - + + void addObservationLabel(std::string const& label); + /*! * Adds a new state. * The variable values have to be given in the same order as the variables have been added. * The number of given variable values for each type needs to match the number of added variables. * After calling this method, no more variables should be added. */ - void addState(storm::storage::sparse::state_type const& state, std::vector&& booleanValues = {}, std::vector&& integerValues = {}, std::vector&& rationalValues = {}); + void addState(storm::storage::sparse::state_type const& state, std::vector&& booleanValues = {}, std::vector&& integerValues = {}, std::vector&& rationalValues = {}, std::vector&& observationLabelValues = {}); /*! * Creates the finalized state valuations object. */ StateValuations build(std::size_t totalStateCount); + uint64_t getBooleanVarCount() const; + uint64_t getIntegerVarCount() const; + uint64_t getLabelCount() const; + private: StateValuations currentStateValuations; uint64_t booleanVarCount; uint64_t integerVarCount; uint64_t rationalVarCount; + uint64_t labelCount; }; } }