diff --git a/src/storm/generator/CompressedState.cpp b/src/storm/generator/CompressedState.cpp index 0bdda94de..b623dab72 100644 --- a/src/storm/generator/CompressedState.cpp +++ b/src/storm/generator/CompressedState.cpp @@ -69,9 +69,13 @@ namespace storm { return result; } - uint32_t unpackStateToObservabilityClass(CompressedState const& state, std::unordered_map& observabilityMap, storm::storage::BitVector const& mask) { + uint32_t unpackStateToObservabilityClass(CompressedState const& state, storm::storage::BitVector const& observationVector, std::unordered_map& observabilityMap, storm::storage::BitVector const& mask) { STORM_LOG_ASSERT(state.size() == mask.size(), "Mask should be as long as state."); storm::storage::BitVector observeClass = state & mask; + if (observationVector.size() != 0) { + observeClass.concat(observationVector); + } + auto it = observabilityMap.find(observeClass); if (it != observabilityMap.end()) { return it->second; diff --git a/src/storm/generator/CompressedState.h b/src/storm/generator/CompressedState.h index 9a12d69c3..302230273 100644 --- a/src/storm/generator/CompressedState.h +++ b/src/storm/generator/CompressedState.h @@ -50,7 +50,7 @@ namespace storm { * @param mask * @return */ - uint32_t unpackStateToObservabilityClass(CompressedState const& state, std::unordered_map& observabilityMap, storm::storage::BitVector const& mask); + uint32_t unpackStateToObservabilityClass(CompressedState const& state, storm::storage::BitVector const& observationVector, std::unordered_map& observabilityMap, storm::storage::BitVector const& mask); /*! * * @param varInfo diff --git a/src/storm/generator/JaniNextStateGenerator.cpp b/src/storm/generator/JaniNextStateGenerator.cpp index 4f12d8d32..f2b05f74d 100644 --- a/src/storm/generator/JaniNextStateGenerator.cpp +++ b/src/storm/generator/JaniNextStateGenerator.cpp @@ -1119,7 +1119,14 @@ namespace storm { return std::make_shared(std::make_shared(model), std::move(identifiers), std::move(identifierToEdgeIndexSetMapping)); } - + + template + storm::storage::BitVector JaniNextStateGenerator::evaluateObservationLabels(CompressedState const& state) const { + STORM_LOG_WARN("There are no observation labels in JANI currenty"); + return storm::storage::BitVector(0); + }; + + template void JaniNextStateGenerator::checkValid() const { // If the program still contains undefined constants and we are not in a parametric setting, assemble an appropriate error message. diff --git a/src/storm/generator/JaniNextStateGenerator.h b/src/storm/generator/JaniNextStateGenerator.h index 1979acf05..790395faf 100644 --- a/src/storm/generator/JaniNextStateGenerator.h +++ b/src/storm/generator/JaniNextStateGenerator.h @@ -91,7 +91,16 @@ namespace storm { * @return The resulting state. */ void applyTransientUpdate(TransientVariableValuation& transientValuation, storm::jani::detail::ConstAssignments const& transientAssignments, storm::expressions::ExpressionEvaluator const& expressionEvaluator); - + + /** + * Required method to overload, but currently throws an error as POMDPs are not yet specified in JANI. + * Furthermore, it might be that these observation labels will not be used and that one uses transient variables instead. + * + * @param state + * @return + */ + virtual storm::storage::BitVector evaluateObservationLabels(CompressedState const& state) const override; + /*! * Retrieves all choices possible from the given state. * diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp index bbc89d7af..fd4578235 100644 --- a/src/storm/generator/NextStateGenerator.cpp +++ b/src/storm/generator/NextStateGenerator.cpp @@ -183,8 +183,7 @@ namespace storm { if (this->mask.size() == 0) { this->mask = computeObservabilityMask(variableInformation); } - - return unpackStateToObservabilityClass(state, observabilityMap, mask); + return unpackStateToObservabilityClass(state, evaluateObservationLabels(state), observabilityMap, mask); } template diff --git a/src/storm/generator/NextStateGenerator.h b/src/storm/generator/NextStateGenerator.h index ae95c3a68..fb72de23f 100644 --- a/src/storm/generator/NextStateGenerator.h +++ b/src/storm/generator/NextStateGenerator.h @@ -68,6 +68,8 @@ namespace storm { virtual storm::models::sparse::StateLabeling label(storm::storage::sparse::StateStorage const& stateStorage, std::vector const& initialStateIndices = {}, std::vector const& deadlockStateIndices = {}) = 0; NextStateGeneratorOptions const& getOptions() const; + + virtual std::shared_ptr generateChoiceOrigins(std::vector& dataForChoiceOrigins) const; @@ -83,7 +85,9 @@ namespace storm { * Creates the state labeling for the given states using the provided labels and expressions. */ storm::models::sparse::StateLabeling label(storm::storage::sparse::StateStorage const& stateStorage, std::vector const& initialStateIndices, std::vector const& deadlockStateIndices, std::vector> labelsAndExpressions); - + + virtual storm::storage::BitVector evaluateObservationLabels(CompressedState const& state) const =0; + void postprocess(StateBehavior& result); /// The options to be used for next-state generation. diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp index c368a976d..68fda0abd 100644 --- a/src/storm/generator/PrismNextStateGenerator.cpp +++ b/src/storm/generator/PrismNextStateGenerator.cpp @@ -682,7 +682,20 @@ namespace storm { return NextStateGenerator::label(stateStorage, initialStateIndices, deadlockStateIndices, labels); } - + + template + storm::storage::BitVector PrismNextStateGenerator::evaluateObservationLabels(CompressedState const& state) const { + // TODO consider to avoid reloading by computing these bitvectors in an earlier build stage + unpackStateIntoEvaluator(state, this->variableInformation, *this->evaluator); + + storm::storage::BitVector result(program.getNumberOfObservationLabels() * 64); + for (uint64_t i = 0; i < program.getNumberOfObservationLabels(); ++i) { + result.setFromInt(64*i,64,this->evaluator->asInt(program.getObservationLabels()[i].getStatePredicateExpression())); + } + return result; + }; + + template std::size_t PrismNextStateGenerator::getNumberOfRewardModels() const { return rewardModels.size(); diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h index 4b643e48b..2b9fdc678 100644 --- a/src/storm/generator/PrismNextStateGenerator.h +++ b/src/storm/generator/PrismNextStateGenerator.h @@ -92,7 +92,13 @@ namespace storm { * @return The labeled choices of the state. */ std::vector> getLabeledChoices(CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter = CommandFilter::All); - + + + /*! + * Evaluate observation labels + */ + storm::storage::BitVector evaluateObservationLabels(CompressedState const& state) const override; + /*! * A recursive helper function to generate a synchronziing distribution. */ diff --git a/src/storm/storage/sparse/ModelComponents.h b/src/storm/storage/sparse/ModelComponents.h index 59d538f63..f71037f1a 100644 --- a/src/storm/storage/sparse/ModelComponents.h +++ b/src/storm/storage/sparse/ModelComponents.h @@ -63,21 +63,19 @@ namespace storm { // stores for each choice from which parts of the input model description it originates boost::optional> choiceOrigins; + // POMDP specific components + // The POMDP observations boost::optional> observabilityClasses; + // Continuous time specific components (CTMCs, Markov Automata): - // True iff the transition values (for Markovian choices) are interpreted as rates. bool rateTransitions; - // The exit rate for each state. Must be given for CTMCs and MAs, if rateTransitions is false. Otherwise, it is optional. boost::optional> exitRates; - // A vector that stores which states are markovian (only for Markov Automata). boost::optional markovianStates; - // Stochastic two player game specific components: - // The matrix of player 1 choices (needed for stochastic two player games boost::optional> player1Matrix; };