From c83721066cd8553d61d4fd3060f27d16884d537d Mon Sep 17 00:00:00 2001
From: Tim Quatmann <tim.quatmann@cs.rwth-aachen.de>
Date: Mon, 9 Mar 2020 14:00:09 +0100
Subject: [PATCH] Use acyclic solver in reward bounded properties.

---
 .../prctl/helper/rewardbounded/EpochModel.cpp | 49 +++++++++++++++----
 1 file changed, 40 insertions(+), 9 deletions(-)

diff --git a/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp b/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp
index f48253fca..63c1e2c76 100644
--- a/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp
+++ b/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp
@@ -1,7 +1,12 @@
 #include "storm/modelchecker/prctl/helper/rewardbounded/EpochModel.h"
 #include "storm/modelchecker/prctl/helper/rewardbounded/MultiDimensionalRewardUnfolding.h"
 
+#include "storm/utility/graph.h"
+#include "storm/environment/solver/MinMaxSolverEnvironment.h"
+#include "storm/environment/solver/SolverEnvironment.h"
+
 #include "storm/exceptions/UncheckedRequirementException.h"
+#include "storm/exceptions/UnexpectedException.h"
 
 namespace storm {
     namespace modelchecker {
@@ -46,9 +51,20 @@ namespace storm {
                     if (epochModel.epochMatrixChanged) {
                         x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
                         storm::solver::GeneralLinearEquationSolverFactory<ValueType> linearEquationSolverFactory;
-                        linEqSolver = linearEquationSolverFactory.create(env, epochModel.epochMatrix);
+                        // We only check for acyclic models if the equation problem has the fixedPointSystem format.
+                        // We could also do this for other formats, however, this requires either matrix conversions or a different 'hasCycle' implementation.
+                        // Also, we would have to match the equationProblemFormat of the acyclic solver.
+                        bool epochMatrixAcyclic = epochModel.equationSolverProblemFormat.get() == storm::solver::LinearEquationSolverProblemFormat::FixedPointSystem && !storm::utility::graph::hasCycle(epochModel.epochMatrix);
+                        Environment acyclicEnv;
+                        if (epochMatrixAcyclic) {
+                            acyclicEnv = env;
+                            acyclicEnv.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Acyclic);
+                            linEqSolver = linearEquationSolverFactory.create(acyclicEnv, epochModel.epochMatrix);
+                        } else {
+                            linEqSolver = linearEquationSolverFactory.create(env, epochModel.epochMatrix);
+                        }
                         linEqSolver->setCachingEnabled(true);
-                        auto req = linEqSolver->getRequirements(env);
+                        auto req = linEqSolver->getRequirements(epochMatrixAcyclic ? acyclicEnv : env);
                         if (lowerBound) {
                             linEqSolver->setLowerBound(lowerBound.get());
                             req.clearLowerBounds();
@@ -57,7 +73,11 @@ namespace storm {
                             linEqSolver->setUpperBound(upperBound.get());
                             req.clearUpperBounds();
                         }
+                        if (epochMatrixAcyclic) {
+                            req.clearAcyclic();
+                        }
                         STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked.");
+                        STORM_LOG_THROW(linEqSolver->getEquationProblemFormat(epochMatrixAcyclic ? acyclicEnv : env) == epochModel.equationSolverProblemFormat.get(), storm::exceptions::UnexpectedException, "The constructed solver uses a different equation problem format then the one that has been specified initially.");
                     }
 
                     // Prepare the right hand side of the equation system
@@ -79,8 +99,6 @@ namespace storm {
                     return storm::utility::vector::filterVector(x, epochModel.epochInStates);
                 }
 
-
-
                 template<typename ValueType>
                 std::vector<ValueType> analyzeTrivialMdpEpochModel(OptimizationDirection dir, EpochModel<ValueType, true>& epochModel) {
                     // Assert that the epoch model is indeed trivial
@@ -138,13 +156,21 @@ namespace storm {
                     if (epochModel.epochMatrixChanged) {
                         x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
                         storm::solver::GeneralMinMaxLinearEquationSolverFactory<ValueType> minMaxLinearEquationSolverFactory;
-                        minMaxSolver = minMaxLinearEquationSolverFactory.create(env, epochModel.epochMatrix);
+                        bool epochMatrixAcyclic = !storm::utility::graph::hasCycle(epochModel.epochMatrix);
+                        Environment acyclicEnv;
+                        if (epochMatrixAcyclic) {
+                            acyclicEnv = env;
+                            acyclicEnv.solver().minMax().setMethod(storm::solver::MinMaxMethod::Acyclic);
+                            minMaxSolver = minMaxLinearEquationSolverFactory.create(acyclicEnv, epochModel.epochMatrix);
+                        } else {
+                            minMaxSolver = minMaxLinearEquationSolverFactory.create(env, epochModel.epochMatrix);
+                        }
                         minMaxSolver->setHasUniqueSolution();
                         minMaxSolver->setHasNoEndComponents();
                         minMaxSolver->setOptimizationDirection(dir);
                         minMaxSolver->setCachingEnabled(true);
-                        minMaxSolver->setTrackScheduler(true);
-                        auto req = minMaxSolver->getRequirements(env, dir, false);
+                        minMaxSolver->setTrackScheduler(!epochMatrixAcyclic); // only track the scheduler if there are cycles
+                        auto req = minMaxSolver->getRequirements(epochMatrixAcyclic ? acyclicEnv : env, dir, false);
                         if (lowerBound) {
                             minMaxSolver->setLowerBound(lowerBound.get());
                             req.clearLowerBounds();
@@ -153,11 +179,16 @@ namespace storm {
                             minMaxSolver->setUpperBound(upperBound.get());
                             req.clearUpperBounds();
                         }
+                        if (epochMatrixAcyclic) {
+                            req.clearAcyclic();
+                        }
                         STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked.");
                         minMaxSolver->setRequirementsChecked();
                     } else {
-                        auto choicesTmp = minMaxSolver->getSchedulerChoices();
-                        minMaxSolver->setInitialScheduler(std::move(choicesTmp));
+                        if (minMaxSolver && minMaxSolver->isTrackSchedulerSet()) {
+                            auto choicesTmp = minMaxSolver->getSchedulerChoices();
+                            minMaxSolver->setInitialScheduler(std::move(choicesTmp));
+                        }
                     }
 
                     // Prepare the right hand side of the equation system