From 4a7ea35959fb8a59e1fe0f9299d72bdd939e2909 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Mon, 28 Dec 2020 22:43:54 -0800 Subject: [PATCH] first version for action mask callbacks in explicit generator --- src/storm/api/builder.h | 5 +-- src/storm/generator/NextStateGenerator.cpp | 34 +++++++++++++++++-- src/storm/generator/NextStateGenerator.h | 32 +++++++++++++++-- .../generator/PrismNextStateGenerator.cpp | 14 ++++++-- src/storm/generator/PrismNextStateGenerator.h | 4 +-- 5 files changed, 77 insertions(+), 12 deletions(-) diff --git a/src/storm/api/builder.h b/src/storm/api/builder.h index fbe838c5d..afc73d15e 100644 --- a/src/storm/api/builder.h +++ b/src/storm/api/builder.h @@ -85,11 +85,12 @@ namespace storm { * @return A builder */ template - storm::builder::ExplicitModelBuilder makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options) { + storm::builder::ExplicitModelBuilder makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options, std::shared_ptr> actionMask = nullptr) { std::shared_ptr> generator; if (model.isPrismProgram()) { - generator = std::make_shared>(model.asPrismProgram(), options); + generator = std::make_shared>(model.asPrismProgram(), options, actionMask); } else if (model.isJaniModel()) { + STORM_LOG_THROW(actionMask == nullptr, storm::exceptions::NotSupportedException, "Action masks for JANI are not yet supported"); generator = std::make_shared>(model.asJaniModel(), options); } else { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot build sparse model from this symbolic model description."); diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp index e958f86b8..be595a4f0 100644 --- a/src/storm/generator/NextStateGenerator.cpp +++ b/src/storm/generator/NextStateGenerator.cpp @@ -16,9 +16,23 @@ namespace storm { namespace generator { - + + template + StateValuationFunctionMask::StateValuationFunctionMask(std::function const& f) + : func(f) + { + // Intentionally left empty + } + + template + bool StateValuationFunctionMask::query(storm::generator::NextStateGenerator const& generator, uint64_t actionIndex) { + auto val = generator.currentStateToSimpleValuation(); + bool res = func(val, actionIndex); + return res; + } + template - NextStateGenerator::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr) { + NextStateGenerator::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr), actionMask(mask) { if(variableInformation.hasOutOfBoundsBit()) { outOfBoundsState = createOutOfBoundsState(variableInformation); } @@ -28,7 +42,7 @@ namespace storm { } template - NextStateGenerator::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr) { + NextStateGenerator::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr), actionMask(mask) { if(variableInformation.hasOutOfBoundsBit()) { outOfBoundsState = createOutOfBoundsState(variableInformation); } @@ -270,6 +284,12 @@ namespace storm { return result; } + template + storm::expressions::SimpleValuation NextStateGenerator::currentStateToSimpleValuation() const { + return unpackStateIntoValuation(*state, variableInformation, *expressionManager); + } + + template void NextStateGenerator::extendStateInformation(storm::json&, bool) const { // Intentionally left empty. @@ -301,7 +321,15 @@ namespace storm { template class NextStateGenerator; + template class ActionMask; + template class StateValuationFunctionMask; + #ifdef STORM_HAVE_CARL + template class ActionMask; + template class StateValuationFunctionMask; + template class ActionMask; + template class StateValuationFunctionMask; + template class NextStateGenerator; template class NextStateGenerator; #endif diff --git a/src/storm/generator/NextStateGenerator.h b/src/storm/generator/NextStateGenerator.h index 5ef4a639a..262d3b29f 100644 --- a/src/storm/generator/NextStateGenerator.h +++ b/src/storm/generator/NextStateGenerator.h @@ -11,6 +11,7 @@ #include "storm/storage/expressions/ExpressionEvaluator.h" #include "storm/storage/sparse/ChoiceOrigins.h" #include "storm/storage/sparse/StateValuations.h" +#include "storm/storage/expressions/SimpleValuation.h" #include "storm/builder/BuilderOptions.h" #include "storm/builder/RewardModelInformation.h" @@ -32,19 +33,41 @@ namespace storm { MA, POMDP }; - + + template + class NextStateGenerator; + + template + class ActionMask { + public: + virtual ~ActionMask() = default; + virtual bool query(storm::generator::NextStateGenerator const &generator, uint64_t actionIndex) = 0; + }; + template + class StateValuationFunctionMask : public ActionMask { + public: + StateValuationFunctionMask(std::function const& f); + virtual ~StateValuationFunctionMask() = default; + bool query(storm::generator::NextStateGenerator const& generator, uint64_t actionIndex) override; + private: + std::function func; + }; + + + + template class NextStateGenerator { public: typedef std::function StateToIdCallback; - NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options); + NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr> const& = nullptr); /*! * Creates a new next state generator. This version of the constructor default-constructs the variable information. * Hence, the subclass is responsible for suitably initializing it in its constructor. */ - NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options); + NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr> const& = nullptr); virtual ~NextStateGenerator() = default; @@ -73,6 +96,7 @@ namespace storm { std::string stateToString(CompressedState const& state) const; storm::json currentStateToJson(bool onlyObservable = false) const; + storm::expressions::SimpleValuation currentStateToSimpleValuation() const; uint32_t observabilityClass(CompressedState const& state) const; @@ -147,6 +171,8 @@ namespace storm { /// A map that stores the indices of states with overlapping guards. boost::optional> overlappingGuardStates; + std::shared_ptr> actionMask; + }; } } diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp index e9df107c5..02ca39e4b 100644 --- a/src/storm/generator/PrismNextStateGenerator.cpp +++ b/src/storm/generator/PrismNextStateGenerator.cpp @@ -22,12 +22,12 @@ namespace storm { namespace generator { template - PrismNextStateGenerator::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options) : PrismNextStateGenerator(program.substituteConstantsFormulas(), options, false) { + PrismNextStateGenerator::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr> const& mask) : PrismNextStateGenerator(program.substituteConstantsFormulas(), options, mask, false) { // Intentionally left empty. } template - PrismNextStateGenerator::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool) : NextStateGenerator(program.getManager(), options), program(program), rewardModels(), hasStateActionRewards(false) { + PrismNextStateGenerator::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr> const& mask, bool) : NextStateGenerator(program.getManager(), options, mask), program(program), rewardModels(), hasStateActionRewards(false) { STORM_LOG_TRACE("Creating next-state generator for PRISM program: " << program); STORM_LOG_THROW(!this->program.specifiesSystemComposition(), storm::exceptions::WrongFormatException, "The explicit next-state generator currently does not support custom system compositions."); @@ -536,6 +536,11 @@ namespace storm { continue; } } + if (this->actionMask != nullptr) { + if (!this->actionMask->query(*this, command.getActionIndex())) { + continue; + } + } // Skip the command, if it is not enabled. if (!this->evaluator->asBool(command.getGuardExpression())) { @@ -616,6 +621,11 @@ namespace storm { void PrismNextStateGenerator::addLabeledChoices(std::vector>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter) { for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) { + if (this->actionMask != nullptr) { + if (!this->actionMask->query(*this, actionIndex)) { + continue; + } + } boost::optional>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(actionIndex, commandFilter); // Only process this action label, if there is at least one feasible solution. diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h index 5f06bffa5..447db6525 100644 --- a/src/storm/generator/PrismNextStateGenerator.h +++ b/src/storm/generator/PrismNextStateGenerator.h @@ -23,7 +23,7 @@ namespace storm { typedef storm::storage::FlatSet CommandSet; enum class CommandFilter {All, Markovian, Probabilistic}; - PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions()); + PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions(), std::shared_ptr> const& = nullptr); /*! * A quick check to detect whether the given model is not supported. @@ -55,7 +55,7 @@ namespace storm { * being called. The last argument is only present to distinguish the signature of this constructor from the * public one. */ - PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool flag); + PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr> const&, bool flag); /*! * Applies an update to the state currently loaded into the evaluator and applies the resulting values to