From 4a7ea35959fb8a59e1fe0f9299d72bdd939e2909 Mon Sep 17 00:00:00 2001 From: Sebastian Junges <sebastian.junges@gmail.com> Date: Mon, 28 Dec 2020 22:43:54 -0800 Subject: [PATCH] first version for action mask callbacks in explicit generator --- src/storm/api/builder.h | 5 +-- src/storm/generator/NextStateGenerator.cpp | 34 +++++++++++++++++-- src/storm/generator/NextStateGenerator.h | 32 +++++++++++++++-- .../generator/PrismNextStateGenerator.cpp | 14 ++++++-- src/storm/generator/PrismNextStateGenerator.h | 4 +-- 5 files changed, 77 insertions(+), 12 deletions(-) diff --git a/src/storm/api/builder.h b/src/storm/api/builder.h index fbe838c5d..afc73d15e 100644 --- a/src/storm/api/builder.h +++ b/src/storm/api/builder.h @@ -85,11 +85,12 @@ namespace storm { * @return A builder */ template<typename ValueType> - storm::builder::ExplicitModelBuilder<ValueType> makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options) { + storm::builder::ExplicitModelBuilder<ValueType> makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options, std::shared_ptr<storm::generator::ActionMask<ValueType>> actionMask = nullptr) { std::shared_ptr<storm::generator::NextStateGenerator<ValueType, uint32_t>> generator; if (model.isPrismProgram()) { - generator = std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(model.asPrismProgram(), options); + generator = std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(model.asPrismProgram(), options, actionMask); } else if (model.isJaniModel()) { + STORM_LOG_THROW(actionMask == nullptr, storm::exceptions::NotSupportedException, "Action masks for JANI are not yet supported"); generator = std::make_shared<storm::generator::JaniNextStateGenerator<ValueType, uint32_t>>(model.asJaniModel(), options); } else { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot build sparse model from this symbolic model description."); diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp index e958f86b8..be595a4f0 100644 --- a/src/storm/generator/NextStateGenerator.cpp +++ b/src/storm/generator/NextStateGenerator.cpp @@ -16,9 +16,23 @@ namespace storm { namespace generator { - + + template<typename ValueType, typename StateType> + StateValuationFunctionMask<ValueType,StateType>::StateValuationFunctionMask(std::function<bool(storm::expressions::SimpleValuation const&, uint64_t)> const& f) + : func(f) + { + // Intentionally left empty + } + + template<typename ValueType, typename StateType> + bool StateValuationFunctionMask<ValueType,StateType>::query(storm::generator::NextStateGenerator<ValueType,StateType> const& generator, uint64_t actionIndex) { + auto val = generator.currentStateToSimpleValuation(); + bool res = func(val, actionIndex); + return res; + } + template<typename ValueType, typename StateType> - NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr) { + NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr), actionMask(mask) { if(variableInformation.hasOutOfBoundsBit()) { outOfBoundsState = createOutOfBoundsState(variableInformation); } @@ -28,7 +42,7 @@ namespace storm { } template<typename ValueType, typename StateType> - NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr) { + NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr), actionMask(mask) { if(variableInformation.hasOutOfBoundsBit()) { outOfBoundsState = createOutOfBoundsState(variableInformation); } @@ -270,6 +284,12 @@ namespace storm { return result; } + template<typename ValueType, typename StateType> + storm::expressions::SimpleValuation NextStateGenerator<ValueType, StateType>::currentStateToSimpleValuation() const { + return unpackStateIntoValuation(*state, variableInformation, *expressionManager); + } + + template<typename ValueType, typename StateType> void NextStateGenerator<ValueType, StateType>::extendStateInformation(storm::json<ValueType>&, bool) const { // Intentionally left empty. @@ -301,7 +321,15 @@ namespace storm { template class NextStateGenerator<double>; + template class ActionMask<double>; + template class StateValuationFunctionMask<double>; + #ifdef STORM_HAVE_CARL + template class ActionMask<storm::RationalNumber>; + template class StateValuationFunctionMask<storm::RationalNumber>; + template class ActionMask<storm::RationalFunction>; + template class StateValuationFunctionMask<storm::RationalFunction>; + template class NextStateGenerator<storm::RationalNumber>; template class NextStateGenerator<storm::RationalFunction>; #endif diff --git a/src/storm/generator/NextStateGenerator.h b/src/storm/generator/NextStateGenerator.h index 5ef4a639a..262d3b29f 100644 --- a/src/storm/generator/NextStateGenerator.h +++ b/src/storm/generator/NextStateGenerator.h @@ -11,6 +11,7 @@ #include "storm/storage/expressions/ExpressionEvaluator.h" #include "storm/storage/sparse/ChoiceOrigins.h" #include "storm/storage/sparse/StateValuations.h" +#include "storm/storage/expressions/SimpleValuation.h" #include "storm/builder/BuilderOptions.h" #include "storm/builder/RewardModelInformation.h" @@ -32,19 +33,41 @@ namespace storm { MA, POMDP }; - + + template<typename ValueType, typename StateType = uint32_t> + class NextStateGenerator; + + template<typename ValueType, typename StateType = uint32_t> + class ActionMask { + public: + virtual ~ActionMask() = default; + virtual bool query(storm::generator::NextStateGenerator<ValueType, StateType> const &generator, uint64_t actionIndex) = 0; + }; + template<typename ValueType, typename StateType = uint32_t> + class StateValuationFunctionMask : public ActionMask<ValueType,StateType> { + public: + StateValuationFunctionMask(std::function<bool (storm::expressions::SimpleValuation const&, uint64_t)> const& f); + virtual ~StateValuationFunctionMask() = default; + bool query(storm::generator::NextStateGenerator<ValueType,StateType> const& generator, uint64_t actionIndex) override; + private: + std::function<bool(storm::expressions::SimpleValuation, uint64_t)> func; + }; + + + + template<typename ValueType, typename StateType> class NextStateGenerator { public: typedef std::function<StateType (CompressedState const&)> StateToIdCallback; - NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options); + NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr); /*! * Creates a new next state generator. This version of the constructor default-constructs the variable information. * Hence, the subclass is responsible for suitably initializing it in its constructor. */ - NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options); + NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr); virtual ~NextStateGenerator() = default; @@ -73,6 +96,7 @@ namespace storm { std::string stateToString(CompressedState const& state) const; storm::json<ValueType> currentStateToJson(bool onlyObservable = false) const; + storm::expressions::SimpleValuation currentStateToSimpleValuation() const; uint32_t observabilityClass(CompressedState const& state) const; @@ -147,6 +171,8 @@ namespace storm { /// A map that stores the indices of states with overlapping guards. boost::optional<std::vector<uint64_t>> overlappingGuardStates; + std::shared_ptr<ActionMask<ValueType,StateType>> actionMask; + }; } } diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp index e9df107c5..02ca39e4b 100644 --- a/src/storm/generator/PrismNextStateGenerator.cpp +++ b/src/storm/generator/PrismNextStateGenerator.cpp @@ -22,12 +22,12 @@ namespace storm { namespace generator { template<typename ValueType, typename StateType> - PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options) : PrismNextStateGenerator<ValueType, StateType>(program.substituteConstantsFormulas(), options, false) { + PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : PrismNextStateGenerator<ValueType, StateType>(program.substituteConstantsFormulas(), options, mask, false) { // Intentionally left empty. } template<typename ValueType, typename StateType> - PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool) : NextStateGenerator<ValueType, StateType>(program.getManager(), options), program(program), rewardModels(), hasStateActionRewards(false) { + PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask, bool) : NextStateGenerator<ValueType, StateType>(program.getManager(), options, mask), program(program), rewardModels(), hasStateActionRewards(false) { STORM_LOG_TRACE("Creating next-state generator for PRISM program: " << program); STORM_LOG_THROW(!this->program.specifiesSystemComposition(), storm::exceptions::WrongFormatException, "The explicit next-state generator currently does not support custom system compositions."); @@ -536,6 +536,11 @@ namespace storm { continue; } } + if (this->actionMask != nullptr) { + if (!this->actionMask->query(*this, command.getActionIndex())) { + continue; + } + } // Skip the command, if it is not enabled. if (!this->evaluator->asBool(command.getGuardExpression())) { @@ -616,6 +621,11 @@ namespace storm { void PrismNextStateGenerator<ValueType, StateType>::addLabeledChoices(std::vector<Choice<ValueType>>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter) { for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) { + if (this->actionMask != nullptr) { + if (!this->actionMask->query(*this, actionIndex)) { + continue; + } + } boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(actionIndex, commandFilter); // Only process this action label, if there is at least one feasible solution. diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h index 5f06bffa5..447db6525 100644 --- a/src/storm/generator/PrismNextStateGenerator.h +++ b/src/storm/generator/PrismNextStateGenerator.h @@ -23,7 +23,7 @@ namespace storm { typedef storm::storage::FlatSet<uint_fast64_t> CommandSet; enum class CommandFilter {All, Markovian, Probabilistic}; - PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions()); + PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions(), std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr); /*! * A quick check to detect whether the given model is not supported. @@ -55,7 +55,7 @@ namespace storm { * being called. The last argument is only present to distinguish the signature of this constructor from the * public one. */ - PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool flag); + PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const&, bool flag); /*! * Applies an update to the state currently loaded into the evaluator and applies the resulting values to