#ifndef STORM_SOLVER_LINEAREQUATIONSOLVER_H_
#define STORM_SOLVER_LINEAREQUATIONSOLVER_H_

#include <vector>
#include <memory>

#include "storm/solver/AbstractEquationSolver.h"
#include "storm/solver/MultiplicationStyle.h"
#include "storm/solver/LinearEquationSolverProblemFormat.h"
#include "storm/solver/LinearEquationSolverRequirements.h"
#include "storm/solver/OptimizationDirection.h"

#include "storm/utility/VectorHelper.h"

#include "storm/storage/SparseMatrix.h"

namespace storm {
    namespace solver {
        
        /*!
         * Identifies the two kinds of functionality offered by LinearEquationSolver: solving a system of linear
         * equations (cf. solveEquations) and repeatedly multiplying a matrix with a vector (cf. repeatedMultiply).
         */
        enum class LinearEquationSolverOperation {
            /// Solving a system of linear equations.
            SolveEquations,
            /// Repeated matrix-vector multiplication.
            MultiplyRepeatedly
        };

        /*!
         * An interface that represents an abstract linear equation solver. In addition to solving a system of linear
         * equations, the functionality to repeatedly multiply a matrix with a given vector is provided.
         *
         * The matrix A that all solving and multiplication operations refer to is provided via setMatrix.
         */
        template<class ValueType>
        class LinearEquationSolver : public AbstractEquationSolver<ValueType> {
        public:
            LinearEquationSolver();

            virtual ~LinearEquationSolver() = default;

            /*!
             * Sets the matrix A of the equation system (passed by const reference).
             *
             * @param A The matrix that defines the equation system.
             */
            virtual void setMatrix(storm::storage::SparseMatrix<ValueType> const& A) = 0;

            /*!
             * Sets the matrix A of the equation system. The matrix is passed by rvalue reference so that
             * implementations may take over its contents.
             *
             * @param A The matrix that defines the equation system.
             */
            virtual void setMatrix(storm::storage::SparseMatrix<ValueType>&& A) = 0;

            /*!
             * If the solver expects the equation system format, it solves Ax = b. If it expects a fixed point
             * format, it solves Ax + b = x. In both versions, the matrix A is required to be square and the problem
             * is required to have a unique solution. The solution will be written to the vector x. Note that the matrix
             * A has to be set (see setMatrix) before calling this function.
             *
             * @param x The solution vector that has to be computed. Its length must be equal to the number of rows of A.
             * @param b The vector b. Its length must be equal to the number of rows of A.
             *
             * @return true
             */
            bool solveEquations(std::vector<ValueType>& x, std::vector<ValueType> const& b) const;

            /*!
             * Performs a matrix-vector multiplication result = A*x + b.
             *
             * @param x The input vector with which to multiply the matrix. Its length must be equal
             * to the number of columns of A.
             * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
             * to the number of rows of A.
             * @param result The target vector into which to write the multiplication result. Its length must be equal
             * to the number of rows of A.
             */
            virtual void multiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const = 0;
            
            /*!
             * Performs a matrix-vector multiplication A*x + b and then minimizes/maximizes over the row groups
             * so that the resulting vector has the size of number of row groups of A.
             *
             * @param dir The direction for the reduction step.
             * @param rowGroupIndices A vector storing the row groups over which to reduce.
             * @param x The input vector with which to multiply the matrix. Its length must be equal
             * to the number of columns of A.
             * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
             * to the number of rows of A.
             * @param result The target vector into which to write the reduced result. Since the reduction yields one
             * value per row group, its length must be equal to the number of row groups of A.
             * @param choices If given, the choices made in the reduction process are written to this vector.
             */
            virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices = nullptr) const;
            
            /*!
             * Retrieves whether this solver offers the gauss-seidel style multiplications
             * (multiplyGaussSeidel and multiplyAndReduceGaussSeidel).
             */
            virtual bool supportsGaussSeidelMultiplication() const;
            
            /*!
             * Performs a matrix-vector multiplication x' = A*x + b. It does so in a gauss-seidel style, i.e. reusing
             * the new x' components in the further multiplication. The result is written to x.
             *
             * @param x The input vector with which to multiply the matrix. Its length must be equal
             * to the number of columns of A.
             * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
             * to the number of rows of A.
             */
            virtual void multiplyGaussSeidel(std::vector<ValueType>& x, std::vector<ValueType> const* b) const;
            
            /*!
             * Performs a matrix-vector multiplication x' = A*x + b and then minimizes/maximizes over the row groups
             * so that the resulting vector has the size of number of row groups of A. It does so in a gauss-seidel
             * style, i.e. reusing the new x' components in the further multiplication. The result is written to x.
             *
             * @param dir The direction for the reduction step.
             * @param rowGroupIndices A vector storing the row groups over which to reduce.
             * @param x The input vector with which to multiply the matrix. Its length must be equal
             * to the number of columns of A.
             * @param b If non-null, this vector is added after the multiplication. If given, its length must be equal
             * to the number of rows of A.
             * @param choices If given, the choices made in the reduction process are written to this vector.
             */
            virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices = nullptr) const;
            
            /*!
             * Performs repeated matrix-vector multiplication, using x[0] = x and x[i + 1] = A*x[i] + b. After
             * performing the necessary multiplications, the result is written to the input vector x. Note that the
             * matrix A has to be set (see setMatrix) before calling this function.
             *
             * @param x The initial vector with which to perform matrix-vector multiplication. Its length must be equal
             * to the number of columns of A.
             * @param b If non-null, this vector is added after each multiplication. If given, its length must be equal
             * to the number of rows of A.
             * @param n The number of times to perform the multiplication.
             */
            void repeatedMultiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, uint_fast64_t n) const;
            
            /*!
             * Retrieves the format in which this solver expects to solve equations. If the solver expects the equation
             * system format, it solves Ax = b. If it expects a fixed point format, it solves Ax + b = x.
             */
            virtual LinearEquationSolverProblemFormat getEquationProblemFormat() const = 0;
            
            /*!
             * Retrieves the requirements of the solver under the current settings. Note that these requirements only
             * apply to solving linear equations and not to the matrix vector multiplications.
             */
            virtual LinearEquationSolverRequirements getRequirements() const;
            
            /*!
             * Sets whether some of the generated data during solver calls should be cached.
             * This may speed up subsequent calls at the cost of increased memory consumption.
             * The method is const because the caching flag is a mutable member.
             *
             * @param value Whether caching should be enabled.
             */
            void setCachingEnabled(bool value) const;
            
            /*!
             * Retrieves whether some of the generated data during solver calls should be cached.
             */
            bool isCachingEnabled() const;
            
            /*!
             * Clears the currently cached data that has been stored during previous calls of the solver.
             */
            virtual void clearCache() const;
            
        protected:
            /*!
             * Solver-specific implementation of equation solving that concrete solvers must provide.
             * NOTE(review): presumably invoked by solveEquations — confirm in the implementation.
             *
             * @param x The solution vector that has to be computed.
             * @param b The vector b.
             */
            virtual bool internalSolveEquations(std::vector<ValueType>& x, std::vector<ValueType> const& b) const = 0;
                        
            // auxiliary storage. If set, this vector has getMatrixRowCount() entries.
            mutable std::unique_ptr<std::vector<ValueType>> cachedRowVector;
            
        private:
            /*!
             * Retrieves the row count of the matrix associated with this solver.
             */
            virtual uint64_t getMatrixRowCount() const = 0;

            /*!
             * Retrieves the column count of the matrix associated with this solver.
             */
            virtual uint64_t getMatrixColumnCount() const = 0;
            
            /// Whether some of the generated data during solver calls should be cached.
            /// The in-class default guarantees a defined value; a constructor mem-initializer still takes precedence.
            mutable bool cachingEnabled = false;
            
            /// An object that can be used to reduce vectors.
            storm::utility::VectorHelper<ValueType> vectorHelper;
        };
        
        // Opaque forward declaration of the equation solver type enumeration; its definition is not in this header.
        // A scoped enum declared this way has a fixed underlying type, so it is complete enough to be stored by value
        // in GeneralLinearEquationSolverFactory below.
        enum class EquationSolverType;
        
        /*!
         * Abstract factory for linear equation solvers. Subclasses implement create() (solver without a matrix) and
         * clone(); the matrix-taking create overloads are provided by this class.
         */
        template<typename ValueType>
        class LinearEquationSolverFactory {
        public:
            virtual ~LinearEquationSolverFactory() = default;

            /*!
             * Creates a new linear equation solver instance with the given matrix.
             *
             * @param matrix The matrix that defines the equation system.
             * @return A pointer to the newly created solver.
             */
            std::unique_ptr<LinearEquationSolver<ValueType>> create(storm::storage::SparseMatrix<ValueType> const& matrix) const;

            /*!
             * Creates a new linear equation solver instance with the given matrix. The caller gives up possession of the
             * matrix by calling this function.
             *
             * @param matrix The matrix that defines the equation system.
             * @return A pointer to the newly created solver.
             */
            std::unique_ptr<LinearEquationSolver<ValueType>> create(storm::storage::SparseMatrix<ValueType>&& matrix) const;

            /*!
             * Creates an equation solver with the current settings, but without a matrix. The matrix has to be set
             * on the returned solver (see LinearEquationSolver::setMatrix) before it is used.
             *
             * @return A pointer to the newly created solver.
             */
            virtual std::unique_ptr<LinearEquationSolver<ValueType>> create() const = 0;

            /*!
             * Creates a copy of this factory.
             *
             * @return A pointer to the newly created copy.
             */
            virtual std::unique_ptr<LinearEquationSolverFactory<ValueType>> clone() const = 0;

            /*!
             * Retrieves the problem format that the solver expects if it was created with the current settings.
             */
            virtual LinearEquationSolverProblemFormat getEquationProblemFormat() const;
            
            /*!
             * Retrieves the requirements of the solver if it was created with the current settings. Note that these
             * requirements only apply to solving linear equations and not to the matrix vector multiplications.
             */
            LinearEquationSolverRequirements getRequirements() const;
        };

        /*!
         * A linear equation solver factory that is configured with an equation solver type.
         */
        template<typename ValueType>
        class GeneralLinearEquationSolverFactory : public LinearEquationSolverFactory<ValueType> {
        public:
            /*!
             * Creates a factory without an explicitly given equation solver type.
             * NOTE(review): presumably the type is taken from the settings — confirm in the implementation.
             */
            GeneralLinearEquationSolverFactory();

            /*!
             * Creates a factory configured with the given equation solver type.
             *
             * @param equationSolver The equation solver type to use.
             */
            GeneralLinearEquationSolverFactory(EquationSolverType const& equationSolver);

            // Re-expose the base class's create(matrix) overloads, which the create() override below would hide.
            using LinearEquationSolverFactory<ValueType>::create;

            std::unique_ptr<LinearEquationSolver<ValueType>> create() const override;

            std::unique_ptr<LinearEquationSolverFactory<ValueType>> clone() const override;

        private:
            /*!
             * Stores the given equation solver type in this factory.
             *
             * @param equationSolver The equation solver type to use.
             */
            void setEquationSolverType(EquationSolverType const& equationSolver);

            /// The equation solver type this factory is configured with.
            EquationSolverType equationSolver;
        };

#ifdef STORM_HAVE_CARL
        /*!
         * Specialization for storm::RationalNumber. Unlike the generic factory, it declares no constructor taking
         * an EquationSolverType and stores no solver type.
         */
        template<>
        class GeneralLinearEquationSolverFactory<storm::RationalNumber> : public LinearEquationSolverFactory<storm::RationalNumber> {
        public:
            // Keep the base class's create(matrix) overloads visible next to the create() override.
            using LinearEquationSolverFactory<storm::RationalNumber>::create;

            std::unique_ptr<LinearEquationSolver<storm::RationalNumber>> create() const override;

            std::unique_ptr<LinearEquationSolverFactory<storm::RationalNumber>> clone() const override;
        };

        /*!
         * Specialization for storm::RationalFunction. Unlike the generic factory, it declares no constructor taking
         * an EquationSolverType and stores no solver type.
         */
        template<>
        class GeneralLinearEquationSolverFactory<storm::RationalFunction> : public LinearEquationSolverFactory<storm::RationalFunction> {
        public:
            // Keep the base class's create(matrix) overloads visible next to the create() override.
            using LinearEquationSolverFactory<storm::RationalFunction>::create;

            std::unique_ptr<LinearEquationSolver<storm::RationalFunction>> create() const override;

            std::unique_ptr<LinearEquationSolverFactory<storm::RationalFunction>> clone() const override;
        };
#endif
    } // namespace solver
} // namespace storm

#endif /* STORM_SOLVER_LINEAREQUATIONSOLVER_H_ */