From 4a7ea35959fb8a59e1fe0f9299d72bdd939e2909 Mon Sep 17 00:00:00 2001
From: Sebastian Junges <sebastian.junges@gmail.com>
Date: Mon, 28 Dec 2020 22:43:54 -0800
Subject: [PATCH] first version for action mask callbacks in explicit generator

---
 src/storm/api/builder.h                       |  5 +--
 src/storm/generator/NextStateGenerator.cpp    | 34 +++++++++++++++++--
 src/storm/generator/NextStateGenerator.h      | 32 +++++++++++++++--
 .../generator/PrismNextStateGenerator.cpp     | 14 ++++++--
 src/storm/generator/PrismNextStateGenerator.h |  4 +--
 5 files changed, 77 insertions(+), 12 deletions(-)

diff --git a/src/storm/api/builder.h b/src/storm/api/builder.h
index fbe838c5d..afc73d15e 100644
--- a/src/storm/api/builder.h
+++ b/src/storm/api/builder.h
@@ -85,11 +85,12 @@ namespace storm {
          * @return A builder
          */
         template<typename ValueType>
-        storm::builder::ExplicitModelBuilder<ValueType> makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options) {
+        storm::builder::ExplicitModelBuilder<ValueType> makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options, std::shared_ptr<storm::generator::ActionMask<ValueType>> actionMask = nullptr) {
             std::shared_ptr<storm::generator::NextStateGenerator<ValueType, uint32_t>> generator;
             if (model.isPrismProgram()) {
-                generator = std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(model.asPrismProgram(), options);
+                generator = std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(model.asPrismProgram(), options, actionMask);
             } else if (model.isJaniModel()) {
+                STORM_LOG_THROW(actionMask == nullptr, storm::exceptions::NotSupportedException, "Action masks for JANI are not yet supported");
                 generator = std::make_shared<storm::generator::JaniNextStateGenerator<ValueType, uint32_t>>(model.asJaniModel(), options);
             } else {
                 STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot build sparse model from this symbolic model description.");
diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp
index e958f86b8..be595a4f0 100644
--- a/src/storm/generator/NextStateGenerator.cpp
+++ b/src/storm/generator/NextStateGenerator.cpp
@@ -16,9 +16,23 @@
 
 namespace storm {
     namespace generator {
-                    
+
+        template<typename ValueType, typename StateType>
+        StateValuationFunctionMask<ValueType,StateType>::StateValuationFunctionMask(std::function<bool(storm::expressions::SimpleValuation const&, uint64_t)> const& f)
+        : func(f)
+        {
+            // Intentionally left empty
+        }
+
+        template<typename ValueType, typename StateType>
+        bool StateValuationFunctionMask<ValueType,StateType>::query(storm::generator::NextStateGenerator<ValueType,StateType> const& generator, uint64_t actionIndex) {
+            auto val = generator.currentStateToSimpleValuation();
+            bool res =  func(val, actionIndex);
+            return res;
+        }
+
         template<typename ValueType, typename StateType>
-        NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr) {
+        NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr), actionMask(mask) {
             if(variableInformation.hasOutOfBoundsBit()) {
                 outOfBoundsState = createOutOfBoundsState(variableInformation);
             }
@@ -28,7 +42,7 @@ namespace storm {
         }
         
         template<typename ValueType, typename StateType>
-        NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr) {
+        NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr), actionMask(mask)  {
             if(variableInformation.hasOutOfBoundsBit()) {
                 outOfBoundsState = createOutOfBoundsState(variableInformation);
             }
@@ -270,6 +284,12 @@ namespace storm {
             return result;
         }
 
+        template<typename ValueType, typename StateType>
+        storm::expressions::SimpleValuation NextStateGenerator<ValueType, StateType>::currentStateToSimpleValuation() const {
+            return unpackStateIntoValuation(*state, variableInformation, *expressionManager);
+        }
+
+
         template<typename ValueType, typename StateType>
         void NextStateGenerator<ValueType, StateType>::extendStateInformation(storm::json<ValueType>&, bool) const {
             // Intentionally left empty.
@@ -301,7 +321,15 @@ namespace storm {
 
         template class NextStateGenerator<double>;
 
+        template class ActionMask<double>;
+        template class StateValuationFunctionMask<double>;
+
 #ifdef STORM_HAVE_CARL
+        template class ActionMask<storm::RationalNumber>;
+        template class StateValuationFunctionMask<storm::RationalNumber>;
+        template class ActionMask<storm::RationalFunction>;
+        template class StateValuationFunctionMask<storm::RationalFunction>;
+
         template class NextStateGenerator<storm::RationalNumber>;
         template class NextStateGenerator<storm::RationalFunction>;
 #endif
diff --git a/src/storm/generator/NextStateGenerator.h b/src/storm/generator/NextStateGenerator.h
index 5ef4a639a..262d3b29f 100644
--- a/src/storm/generator/NextStateGenerator.h
+++ b/src/storm/generator/NextStateGenerator.h
@@ -11,6 +11,7 @@
 #include "storm/storage/expressions/ExpressionEvaluator.h"
 #include "storm/storage/sparse/ChoiceOrigins.h"
 #include "storm/storage/sparse/StateValuations.h"
+#include "storm/storage/expressions/SimpleValuation.h"
 
 #include "storm/builder/BuilderOptions.h"
 #include "storm/builder/RewardModelInformation.h"
@@ -32,19 +33,41 @@ namespace storm {
             MA,
             POMDP
         };
-        
+
+        template<typename ValueType, typename StateType = uint32_t>
+        class NextStateGenerator;
+
+        template<typename ValueType, typename StateType = uint32_t>
+        class ActionMask {
+        public:
+            virtual ~ActionMask() = default;
+            virtual bool query(storm::generator::NextStateGenerator<ValueType, StateType> const &generator, uint64_t actionIndex) = 0;
+        };
+
         template<typename ValueType, typename StateType = uint32_t>
+        class StateValuationFunctionMask : public ActionMask<ValueType,StateType> {
+        public:
+            StateValuationFunctionMask(std::function<bool (storm::expressions::SimpleValuation const&, uint64_t)> const& f);
+            virtual ~StateValuationFunctionMask() = default;
+            bool query(storm::generator::NextStateGenerator<ValueType,StateType> const& generator, uint64_t actionIndex) override;
+        private:
+            std::function<bool(storm::expressions::SimpleValuation, uint64_t)> func;
+        };
+
+
+        
+        template<typename ValueType, typename StateType>
         class NextStateGenerator {
         public:
             typedef std::function<StateType (CompressedState const&)> StateToIdCallback;
 
-            NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options);
+            NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr);
             
             /*!
              * Creates a new next state generator. This version of the constructor default-constructs the variable information.
              * Hence, the subclass is responsible for suitably initializing it in its constructor.
              */
-            NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options);
+            NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr);
             
             virtual ~NextStateGenerator() = default;
             
@@ -73,6 +96,7 @@ namespace storm {
             std::string stateToString(CompressedState const& state) const;
 
             storm::json<ValueType> currentStateToJson(bool onlyObservable = false) const;
+            storm::expressions::SimpleValuation currentStateToSimpleValuation() const;
 
             uint32_t observabilityClass(CompressedState const& state) const;
 
@@ -147,6 +171,8 @@ namespace storm {
             /// A map that stores the indices of states with overlapping guards.
             boost::optional<std::vector<uint64_t>> overlappingGuardStates;
 
+            std::shared_ptr<ActionMask<ValueType,StateType>> actionMask;
+
         };
     }
 }
diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp
index e9df107c5..02ca39e4b 100644
--- a/src/storm/generator/PrismNextStateGenerator.cpp
+++ b/src/storm/generator/PrismNextStateGenerator.cpp
@@ -22,12 +22,12 @@ namespace storm {
     namespace generator {
         
         template<typename ValueType, typename StateType>
-        PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options) : PrismNextStateGenerator<ValueType, StateType>(program.substituteConstantsFormulas(), options, false) {
+        PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : PrismNextStateGenerator<ValueType, StateType>(program.substituteConstantsFormulas(), options, mask, false) {
             // Intentionally left empty.
         }
         
         template<typename ValueType, typename StateType>
-        PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool) : NextStateGenerator<ValueType, StateType>(program.getManager(), options), program(program), rewardModels(), hasStateActionRewards(false) {
+        PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options,  std::shared_ptr<ActionMask<ValueType,StateType>> const& mask, bool) : NextStateGenerator<ValueType, StateType>(program.getManager(), options, mask), program(program), rewardModels(), hasStateActionRewards(false) {
             STORM_LOG_TRACE("Creating next-state generator for PRISM program: " << program);
             STORM_LOG_THROW(!this->program.specifiesSystemComposition(), storm::exceptions::WrongFormatException, "The explicit next-state generator currently does not support custom system compositions.");
                         
@@ -536,6 +536,11 @@ namespace storm {
                             continue;
                         }
                     }
+                    if (this->actionMask != nullptr) {
+                        if (!this->actionMask->query(*this, command.getActionIndex())) {
+                            continue;
+                        }
+                    }
 
                     // Skip the command, if it is not enabled.
                     if (!this->evaluator->asBool(command.getGuardExpression())) {
@@ -616,6 +621,11 @@ namespace storm {
         void PrismNextStateGenerator<ValueType, StateType>::addLabeledChoices(std::vector<Choice<ValueType>>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter) {
 
             for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) {
+                if (this->actionMask != nullptr) {
+                    if (!this->actionMask->query(*this, actionIndex)) {
+                        continue;
+                    }
+                }
                 boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(actionIndex, commandFilter);
 
                 // Only process this action label, if there is at least one feasible solution.
diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h
index 5f06bffa5..447db6525 100644
--- a/src/storm/generator/PrismNextStateGenerator.h
+++ b/src/storm/generator/PrismNextStateGenerator.h
@@ -23,7 +23,7 @@ namespace storm {
             typedef storm::storage::FlatSet<uint_fast64_t> CommandSet;
             enum class CommandFilter {All, Markovian, Probabilistic};
 
-            PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions());
+            PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions(), std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr);
             
             /*!
              * A quick check to detect whether the given model is not supported.
@@ -55,7 +55,7 @@ namespace storm {
              * being called. The last argument is only present to distinguish the signature of this constructor from the
              * public one.
              */
-            PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool flag);
+            PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const&, bool flag);
             
             /*!
              * Applies an update to the state currently loaded into the evaluator and applies the resulting values to