From a348f6ea8e828aa7de690250c157d45b9bc6caed Mon Sep 17 00:00:00 2001
From: TimQu <tim.quatmann@cs.rwth-aachen.de>
Date: Tue, 18 Jul 2017 11:24:24 +0200
Subject: [PATCH] function to apply a given scheduler to a nondeterministic
 model

---
 .../models/sparse/NondeterministicModel.cpp   | 20 +++++++++++++++++++
 .../models/sparse/NondeterministicModel.h     | 14 +++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/src/storm/models/sparse/NondeterministicModel.cpp b/src/storm/models/sparse/NondeterministicModel.cpp
index e823d455a..5b5aebe89 100644
--- a/src/storm/models/sparse/NondeterministicModel.cpp
+++ b/src/storm/models/sparse/NondeterministicModel.cpp
@@ -2,6 +2,9 @@
 
 #include "storm/models/sparse/StandardRewardModel.h"
 #include "storm/models/sparse/MarkovAutomaton.h"
+#include "storm/storage/Scheduler.h"
+#include "storm/storage/memorystructure/MemoryStructureBuilder.h"
+#include "storm/storage/memorystructure/SparseModelMemoryProduct.h"
 
 #include "storm/adapters/RationalFunctionAdapter.h"
 
@@ -46,6 +49,23 @@ namespace storm {
                 }
             }
             
+            template<typename ValueType, typename RewardModelType>
+            std::shared_ptr<storm::models::sparse::Model<ValueType, RewardModelType>> NondeterministicModel<ValueType, RewardModelType>::applyScheduler(storm::storage::Scheduler<ValueType> const& scheduler, bool dropUnreachableStates) {
+                boost::optional<storm::storage::SparseModelMemoryProduct<ValueType>> memoryProduct;
+                if (scheduler.isMemorylessScheduler()) {
+                    storm::storage::MemoryStructure memStruct = storm::storage::MemoryStructureBuilder<ValueType, RewardModelType>::buildTrivialMemoryStructure(*this);
+                    memoryProduct = memStruct.product(*this);
+                } else {
+                    boost::optional<storm::storage::MemoryStructure> const& memStruct = scheduler.getMemoryStructure();
+                    STORM_LOG_ASSERT(memStruct, "Memoryless scheduler without memory structure.");
+                    memoryProduct = memStruct->product(*this);
+                }
+                if (!dropUnreachableStates) {
+                    memoryProduct->setBuildFullProduct();
+                }
+                return memoryProduct->build(scheduler);
+            }
+            
             template<typename ValueType, typename RewardModelType>
             void NondeterministicModel<ValueType, RewardModelType>::printModelInformationToStream(std::ostream& out) const {
                 this->printModelInformationHeaderToStream(out);
diff --git a/src/storm/models/sparse/NondeterministicModel.h b/src/storm/models/sparse/NondeterministicModel.h
index 1f2bd157a..ab0a21a04 100644
--- a/src/storm/models/sparse/NondeterministicModel.h
+++ b/src/storm/models/sparse/NondeterministicModel.h
@@ -5,6 +5,13 @@
 #include "storm/utility/OsDetection.h"
 
 namespace storm {
+    
+    // Forward declare Scheduler class.
+    namespace storage {
+        template <typename ValueType>
+        class Scheduler;
+    }
+    
     namespace models {
         namespace sparse {
             
@@ -48,6 +55,13 @@ namespace storm {
                 
                 virtual void reduceToStateBasedRewards() override;
                 
+                /*!
+                 * Applies the given scheduler to this model.
+                 * @param scheduler the considered scheduler.
+                 * @param dropUnreachableStates if set, the resulting model only considers the states that are reachable from an initial state
+                 */
+                std::shared_ptr<storm::models::sparse::Model<ValueType, RewardModelType>> applyScheduler(storm::storage::Scheduler<ValueType> const& scheduler, bool dropUnreachableStates = true);
+                
                 virtual void printModelInformationToStream(std::ostream& out) const override;
                 
                 virtual void writeDotToStream(std::ostream& outStream, bool includeLabeling = true, storm::storage::BitVector const* subsystem = nullptr, std::vector<ValueType> const* firstValue = nullptr, std::vector<ValueType> const* secondValue = nullptr, std::vector<uint_fast64_t> const* stateColoring = nullptr, std::vector<std::string> const* colors = nullptr, std::vector<uint_fast64_t>* scheduler = nullptr, bool finalizeOutput = true) const override;