From 61a8b9bb29bcbcee9835b6c11758648d82f04b97 Mon Sep 17 00:00:00 2001
From: dehnert <dehnert@cs.rwth-aachen.de>
Date: Mon, 27 Jun 2016 16:59:37 +0200
Subject: [PATCH] more work on solvers

Former-commit-id: 14fad8ac36ea0577422904a8ffded720a7fe99b9
---
 .../modules/MinMaxEquationSolverSettings.cpp  |   1 +
 src/solver/EigenLinearEquationSolver.cpp      |   1 +
 src/solver/LinearEquationSolver.cpp           |  39 ++++-
 src/solver/LinearEquationSolver.h             |  26 +--
 .../StandardMinMaxLinearEquationSolver.cpp    | 156 +++++++++++++++---
 .../StandardMinMaxLinearEquationSolver.h      |   9 +
 6 files changed, 195 insertions(+), 37 deletions(-)

diff --git a/src/settings/modules/MinMaxEquationSolverSettings.cpp b/src/settings/modules/MinMaxEquationSolverSettings.cpp
index 8a7067094..6b408d9ff 100644
--- a/src/settings/modules/MinMaxEquationSolverSettings.cpp
+++ b/src/settings/modules/MinMaxEquationSolverSettings.cpp
@@ -16,6 +16,7 @@ namespace storm {
             const std::string MinMaxEquationSolverSettings::maximalIterationsOptionName = "maxiter";
             const std::string MinMaxEquationSolverSettings::maximalIterationsOptionShortName = "i";
             const std::string MinMaxEquationSolverSettings::precisionOptionName = "precision";
+            const std::string MinMaxEquationSolverSettings::absoluteOptionName = "absolute";
 
             MinMaxEquationSolverSettings::MinMaxEquationSolverSettings() : ModuleSettings(moduleName) {
                 std::vector<std::string> minMaxSolvingTechniques = {"vi", "value-iteration", "pi", "policy-iteration"};
diff --git a/src/solver/EigenLinearEquationSolver.cpp b/src/solver/EigenLinearEquationSolver.cpp
index cac22e645..e7d4bd2d6 100644
--- a/src/solver/EigenLinearEquationSolver.cpp
+++ b/src/solver/EigenLinearEquationSolver.cpp
@@ -251,6 +251,7 @@ namespace storm {
                 } else {
                     nextX->noalias() = *eigenA * *currentX;
                 }
+                std::swap(nextX, currentX);
             }
             
             // If the last result we obtained is not the one in the input vector x, we swap the result there.
diff --git a/src/solver/LinearEquationSolver.cpp b/src/solver/LinearEquationSolver.cpp
index 4cbfb16da..b7f15d90d 100644
--- a/src/solver/LinearEquationSolver.cpp
+++ b/src/solver/LinearEquationSolver.cpp
@@ -13,6 +13,39 @@
 namespace storm {
     namespace solver {
         
+        template<typename ValueType>
+        void LinearEquationSolver<ValueType>::performMatrixVectorMultiplication(std::vector<ValueType>& x, std::vector<ValueType> const* b, uint_fast64_t n, std::vector<ValueType>* multiplyResult) const {
+            
+            // Set up some temporary variables so that we can just swap pointers instead of copying the result after
+            // each iteration.
+            std::vector<ValueType>* currentX = &x;
+            
+            bool multiplyResultProvided = true;
+            std::vector<ValueType>* nextX = multiplyResult;
+            if (nextX == nullptr) {
+                nextX = new std::vector<ValueType>(x.size());
+                multiplyResultProvided = false;
+            }
+            std::vector<ValueType> const* copyX = nextX;
+            
+            // Now perform matrix-vector multiplication as long as we meet the bound.
+            for (uint_fast64_t i = 0; i < n; ++i) {
+                this->performMatrixVectorMultiplication(*currentX, *nextX, b);
+                std::swap(nextX, currentX);
+            }
+            
+            // If we performed an odd number of repetitions, we need to swap the contents of currentVector and x,
+            // because the output is supposed to be stored in the input vector x.
+            if (currentX == copyX) {
+                std::swap(x, *currentX);
+            }
+            
+            // If the vector for the temporary multiplication result was not provided, we need to delete it.
+            if (!multiplyResultProvided) {
+                delete copyX;
+            }
+        }
+        
         template<typename ValueType>
         std::unique_ptr<LinearEquationSolver<ValueType>> LinearEquationSolverFactory<ValueType>::create(storm::storage::SparseMatrix<ValueType>&& matrix) const {
             return create(matrix);
@@ -87,7 +120,11 @@ namespace storm {
         std::unique_ptr<LinearEquationSolverFactory<storm::RationalFunction>> GeneralLinearEquationSolverFactory<storm::RationalFunction>::clone() const {
             return std::make_unique<GeneralLinearEquationSolverFactory<storm::RationalFunction>>(*this);
         }
-        
+
+        template class LinearEquationSolver<double>;
+        template class LinearEquationSolver<storm::RationalNumber>;
+        template class LinearEquationSolver<storm::RationalFunction>;
+
         template class LinearEquationSolverFactory<double>;
         template class LinearEquationSolverFactory<storm::RationalNumber>;
         template class LinearEquationSolverFactory<storm::RationalFunction>;
diff --git a/src/solver/LinearEquationSolver.h b/src/solver/LinearEquationSolver.h
index 8585f2136..e12fa6eb5 100644
--- a/src/solver/LinearEquationSolver.h
+++ b/src/solver/LinearEquationSolver.h
@@ -34,6 +34,18 @@ namespace storm {
              */
             virtual void solveEquationSystem(std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<ValueType>* multiplyResult = nullptr) const = 0;
             
+            /*!
+             * Performs on matrix-vector multiplication x' = A*x + b.
+             *
+             * @param x The input vector with which to multiply the matrix. Its length must be equal
+             * to the number of columns of A.
+             * @param result The target vector into which to write the multiplication result. Its length must be equal
+             * to the number of rows of A.
+             * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
+             * to the number of rows of A.
+             */
+            virtual void performMatrixVectorMultiplication(std::vector<ValueType>& x, std::vector<ValueType>& result, std::vector<ValueType> const* b = nullptr) const = 0;
+            
             /*!
              * Performs repeated matrix-vector multiplication, using x[0] = x and x[i + 1] = A*x[i] + b. After
              * performing the necessary multiplications, the result is written to the input vector x. Note that the
@@ -46,19 +58,7 @@ namespace storm {
              * @param multiplyResult If non-null, this memory is used as a scratch memory. If given, the length of this
              * vector must be equal to the number of rows of A.
              */
-            virtual void performMatrixVectorMultiplication(std::vector<ValueType>& x, std::vector<ValueType> const* b = nullptr, uint_fast64_t n = 1, std::vector<ValueType>* multiplyResult = nullptr) const = 0;
-            
-            /*!
-             * Performs on matrix-vector multiplication x' = A*x + b.
-             *
-             * @param x The input vector with which to multiply the matrix. Its length must be equal
-             * to the number of columns of A.
-             * @param result The target vector into which to write the multiplication result. Its length must be equal
-             * to the number of rows of A.
-             * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
-             * to the number of rows of A.
-             */
-            virtual void performMatrixVectorMultiplication(std::vector<ValueType>& x, std::vector<ValueType>& result, std::vector<ValueType> const* b = nullptr) const = 0;
+            void performMatrixVectorMultiplication(std::vector<ValueType>& x, std::vector<ValueType> const* b = nullptr, uint_fast64_t n = 1, std::vector<ValueType>* multiplyResult = nullptr) const;
         };
         
         template<typename ValueType>
diff --git a/src/solver/StandardMinMaxLinearEquationSolver.cpp b/src/solver/StandardMinMaxLinearEquationSolver.cpp
index d6213ddd7..0e9791969 100644
--- a/src/solver/StandardMinMaxLinearEquationSolver.cpp
+++ b/src/solver/StandardMinMaxLinearEquationSolver.cpp
@@ -11,15 +11,17 @@
 #include "src/utility/vector.h"
 #include "src/utility/macros.h"
 #include "src/exceptions/InvalidSettingsException.h"
-
+#include "src/exceptions/InvalidStateException.h"
 namespace storm {
     namespace solver {
         
         StandardMinMaxLinearEquationSolverSettings::StandardMinMaxLinearEquationSolverSettings() {
             // Get the settings object to customize linear solving.
             storm::settings::modules::MinMaxEquationSolverSettings const& settings = storm::settings::getModule<storm::settings::modules::MinMaxEquationSolverSettings>();
-
+            
             maximalNumberOfIterations = settings.getMaximalIterationCount();
+            precision = settings.getPrecision();
+            relative = settings.getConvergenceCriterion() == storm::settings::modules::MinMaxEquationSolverSettings::ConvergenceCriterion::Relative;
             
             auto method = settings.getMinMaxEquationSolvingMethod();
             switch (method) {
@@ -66,7 +68,7 @@ namespace storm {
         StandardMinMaxLinearEquationSolver<ValueType>::StandardMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory, StandardMinMaxLinearEquationSolverSettings const& settings) : settings(settings), linearEquationSolverFactory(std::move(linearEquationSolverFactory)), localA(nullptr), A(A) {
             // Intentionally left empty.
         }
-
+        
         template<typename ValueType>
         StandardMinMaxLinearEquationSolver<ValueType>::StandardMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory, StandardMinMaxLinearEquationSolverSettings const& settings) : settings(settings), linearEquationSolverFactory(std::move(linearEquationSolverFactory)), localA(std::make_unique<storm::storage::SparseMatrix<ValueType>>(std::move(A))), A(*localA) {
             // Intentionally left empty.
@@ -82,11 +84,100 @@ namespace storm {
         
         template<typename ValueType>
         void StandardMinMaxLinearEquationSolver<ValueType>::solveEquationSystemPolicyIteration(OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<ValueType>* multiplyResult, std::vector<ValueType>* newX) const {
-            // FIXME.
+            
+            // Create scratch memory if none was provided.
+            bool multiplyResultMemoryProvided = multiplyResult != nullptr;
+            if (multiplyResult == nullptr) {
+                multiplyResult = new std::vector<ValueType>(this->A.getRowCount());
+            }
+
+            // Create the initial scheduler.
+            std::vector<storm::storage::sparse::state_type> scheduler(this->A.getRowGroupCount());
+            
+            // Create a vector for storing the right-hand side of the inner equation system.
+            std::vector<ValueType> subB(this->A.getRowGroupCount());
+            
+            // Create a vector that the inner equation solver can use as scratch memory.
+            std::vector<ValueType> deterministicMultiplyResult(this->A.getRowGroupCount());
+            
+            Status status = Status::InProgress;
+            uint64_t iterations = 0;
+            do {
+                // Resolve the nondeterminism according to the current scheduler.
+                storm::storage::SparseMatrix<ValueType> submatrix = this->A.selectRowsFromRowGroups(scheduler, true);
+                submatrix.convertToEquationSystem();
+                storm::utility::vector::selectVectorValues<ValueType>(subB, scheduler, this->A.getRowGroupIndices(), b);
+
+                // Solve the equation system for the 'DTMC'.
+                // FIXME: we need to remove the 0- and 1- states to make the solution unique.
+                auto solver = linearEquationSolverFactory->create(submatrix);
+                solver->solveEquationSystem(x, subB, &deterministicMultiplyResult);
+
+                // Go through the multiplication result and see whether we can improve any of the choices.
+                bool schedulerImproved = false;
+                for (uint_fast64_t group = 0; group < this->A.getRowGroupCount(); ++group) {
+                    for (uint_fast64_t choice = this->A.getRowGroupIndices()[group]; choice < this->A.getRowGroupIndices()[group + 1]; ++choice) {
+                        // If the choice is the currently selected one, we can skip it.
+                        if (choice - this->A.getRowGroupIndices()[group] == scheduler[group]) {
+                            continue;
+                        }
+                            
+                        // Create the value of the choice.
+                        ValueType choiceValue = storm::utility::zero<ValueType>();
+                        for (auto const& entry : this->A.getRow(choice)) {
+                            choiceValue += entry.getValue() * x[entry.getColumn()];
+                        }
+                        choiceValue += b[choice];
+                        
+                        // If the value is strictly better than the solution of the inner system, we need to improve the scheduler.
+                        if (valueImproved(dir, x[group], choiceValue)) {
+                            schedulerImproved = true;
+                            scheduler[group] = choice - this->A.getRowGroupIndices()[group];
+                        }
+                    }
+                }
+                
+                // If the scheduler did not improve, we are done.
+                if (!schedulerImproved) {
+                    status = Status::Converged;
+                }
+                
+                // Update environment variables.
+                ++iterations;
+                status = updateStatusIfNotConverged(status, x, iterations);
+            } while (status == Status::InProgress);
+            
+            reportStatus(status, iterations);
+            
+            // If requested, we store the scheduler for retrieval.
+            if (this->isTrackSchedulerSet()) {
+                this->scheduler = std::make_unique<storm::storage::TotalScheduler>(std::move(scheduler));
+            }
+            
+            if (!multiplyResultMemoryProvided) {
+                delete multiplyResult;
+            }
+        }
+        
+        template<typename ValueType>
+        bool StandardMinMaxLinearEquationSolver<ValueType>::valueImproved(OptimizationDirection dir, ValueType const& value1, ValueType const& value2) const {
+            if (dir == OptimizationDirection::Minimize) {
+                if (value1 > value2) {
+                    return true;
+                }
+                return false;
+            } else {
+                if (value1 < value2) {
+                    return true;
+                }
+                return false;
+            }
         }
         
         template<typename ValueType>
         void StandardMinMaxLinearEquationSolver<ValueType>::solveEquationSystemValueIteration(OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<ValueType>* multiplyResult, std::vector<ValueType>* newX) const {
+            std::unique_ptr<storm::solver::LinearEquationSolver<ValueType>> solver = linearEquationSolverFactory->create(A);
+            
             // Create scratch memory if none was provided.
             bool multiplyResultMemoryProvided = multiplyResult != nullptr;
             if (multiplyResult == nullptr) {
@@ -101,33 +192,29 @@ namespace storm {
             // Keep track of which of the vectors for x is the auxiliary copy.
             std::vector<ValueType>* copyX = newX;
             
+            // Proceed with the iterations as long as the method did not converge or reach the maximum number of iterations.
             uint64_t iterations = 0;
-            bool converged = false;
             
-            // Proceed with the iterations as long as the method did not converge or reach the maximum number of iterations.
-            while (!converged && iterations < this->getSettings().getMaximalNumberOfIterations() && (!this->hasCustomTerminationCondition() || this->getTerminationCondition().terminateNow(*currentX))) {
+            Status status = Status::InProgress;
+            while (status == Status::InProgress) {
                 // Compute x' = A*x + b.
-                this->A.multiplyWithVector(*currentX, *multiplyResult);
-                storm::utility::vector::addVectors(*multiplyResult, b, *multiplyResult);
+                solver->performMatrixVectorMultiplication(*currentX, *multiplyResult, &b);
                 
-                // Reduce the vector x' by applying min/max for all non-deterministic choices as given by the topmost
-                // element of the min/max operator stack.
+                // Reduce the vector x' by applying min/max for all non-deterministic choices.
                 storm::utility::vector::reduceVectorMinOrMax(dir, *multiplyResult, *newX, this->A.getRowGroupIndices());
                 
                 // Determine whether the method converged.
-                converged = storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *newX, static_cast<ValueType>(this->getSettings().getPrecision()), this->getSettings().getRelativeTerminationCriterion());
+                if (storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *newX, static_cast<ValueType>(this->getSettings().getPrecision()), this->getSettings().getRelativeTerminationCriterion())) {
+                    status = Status::Converged;
+                }
                 
                 // Update environment variables.
                 std::swap(currentX, newX);
                 ++iterations;
+                status = updateStatusIfNotConverged(status, *currentX, iterations);
             }
             
-            // Check if the solver converged and issue a warning otherwise.
-            if (converged) {
-                STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations.");
-            } else {
-                STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations.");
-            }
+            reportStatus(status, iterations);
             
             // If we performed an odd number of iterations, we need to swap the x and currentX, because the newest result
             // is currently stored in currentX, but x is the output vector.
@@ -135,10 +222,10 @@ namespace storm {
                 std::swap(x, *currentX);
             }
             
+            // Dispose of allocated scratch memory.
             if (!xMemoryProvided) {
                 delete copyX;
             }
-            
             if (!multiplyResultMemoryProvided) {
                 delete multiplyResult;
             }
@@ -167,11 +254,34 @@ namespace storm {
             }
         }
         
+        template<typename ValueType>
+        typename StandardMinMaxLinearEquationSolver<ValueType>::Status StandardMinMaxLinearEquationSolver<ValueType>::updateStatusIfNotConverged(Status status, std::vector<ValueType> const& x, uint64_t iterations) const {
+            if (status != Status::Converged) {
+                if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x)) {
+                    status = Status::TerminatedEarly;
+                } else if (iterations >= this->getSettings().getMaximalNumberOfIterations()) {
+                    status = Status::MaximalIterationsExceeded;
+                }
+            }
+            return status;
+        }
+        
+        template<typename ValueType>
+        void StandardMinMaxLinearEquationSolver<ValueType>::reportStatus(Status status, uint64_t iterations) const {
+            switch (status) {
+                case Status::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break;
+                case Status::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break;
+                case Status::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
+                default:
+                    STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
+            }
+        }
+        
         template<typename ValueType>
         StandardMinMaxLinearEquationSolverSettings const& StandardMinMaxLinearEquationSolver<ValueType>::getSettings() const {
             return settings;
         }
-
+        
         template<typename ValueType>
         StandardMinMaxLinearEquationSolverSettings& StandardMinMaxLinearEquationSolver<ValueType>::getSettings() {
             return settings;
@@ -234,12 +344,12 @@ namespace storm {
         GmmxxMinMaxLinearEquationSolverFactory<ValueType>::GmmxxMinMaxLinearEquationSolverFactory(bool trackScheduler) : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Gmmxx, trackScheduler) {
             // Intentionally left empty.
         }
-
+        
         template<typename ValueType>
         EigenMinMaxLinearEquationSolverFactory<ValueType>::EigenMinMaxLinearEquationSolverFactory(bool trackScheduler) : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Eigen, trackScheduler) {
             // Intentionally left empty.
         }
-
+        
         template<typename ValueType>
         NativeMinMaxLinearEquationSolverFactory<ValueType>::NativeMinMaxLinearEquationSolverFactory(bool trackScheduler) : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Native, trackScheduler) {
             // Intentionally left empty.
@@ -257,6 +367,6 @@ namespace storm {
         template class EigenMinMaxLinearEquationSolverFactory<double>;
         template class NativeMinMaxLinearEquationSolverFactory<double>;
         template class EliminationMinMaxLinearEquationSolverFactory<double>;
-
+        
     }
 }
\ No newline at end of file
diff --git a/src/solver/StandardMinMaxLinearEquationSolver.h b/src/solver/StandardMinMaxLinearEquationSolver.h
index 50b4fa88b..e3931a13d 100644
--- a/src/solver/StandardMinMaxLinearEquationSolver.h
+++ b/src/solver/StandardMinMaxLinearEquationSolver.h
@@ -47,6 +47,15 @@ namespace storm {
             void solveEquationSystemPolicyIteration(OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<ValueType>* multiplyResult, std::vector<ValueType>* newX) const;
             void solveEquationSystemValueIteration(OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<ValueType>* multiplyResult, std::vector<ValueType>* newX) const;
 
+            bool valueImproved(OptimizationDirection dir, ValueType const& value1, ValueType const& value2) const;
+            
+            enum class Status {
+                Converged, TerminatedEarly, MaximalIterationsExceeded, InProgress
+            };
+
+            Status updateStatusIfNotConverged(Status status, std::vector<ValueType> const& x, uint64_t iterations) const;
+            void reportStatus(Status status, uint64_t iterations) const;
+            
             /// The settings of this solver.
             StandardMinMaxLinearEquationSolverSettings settings;