#include "gtest/gtest.h"
#include "storm-config.h"
#include "test/storm_gtest.h"

#include "storm/solver/LinearEquationSolver.h"
#include "storm/environment/solver/NativeSolverEnvironment.h"
#include "storm/environment/solver/GmmxxSolverEnvironment.h"
#include "storm/environment/solver/EigenSolverEnvironment.h"

#include "storm/utility/vector.h"
namespace {
    
    /// Test configuration: native solver, power method, double arithmetic,
    /// termination precision 1e-10.
    class NativeDoublePowerEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            auto& nativeEnvironment = solverEnvironment.native();
            nativeEnvironment.setMethod(storm::solver::NativeLinearEquationSolverMethod::Power);
            nativeEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-10"));
            return environment;
        }
    };
    
    /// Test configuration: native solver, power method, double arithmetic, with soundness
    /// forced, the relative termination criterion disabled, and precision 1e-6.
    class NativeDoubleSoundPowerEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setForceSoundness(true);
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            auto& nativeEnvironment = solverEnvironment.native();
            nativeEnvironment.setMethod(storm::solver::NativeLinearEquationSolverMethod::Power);
            // Disables the relative criterion, i.e. the 1e-6 precision below is not relative.
            nativeEnvironment.setRelativeTerminationCriterion(false);
            nativeEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-6"));
            return environment;
        }
    };
    
    /// Test configuration: native solver, Jacobi method, double arithmetic,
    /// termination precision 1e-10.
    class NativeDoubleJacobiEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            auto& nativeEnvironment = solverEnvironment.native();
            nativeEnvironment.setMethod(storm::solver::NativeLinearEquationSolverMethod::Jacobi);
            nativeEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-10"));
            return environment;
        }
    };
    
    /// Test configuration: native solver, Gauss-Seidel method, double arithmetic,
    /// termination precision 1e-10.
    class NativeDoubleGaussSeidelEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            auto& nativeEnvironment = solverEnvironment.native();
            nativeEnvironment.setMethod(storm::solver::NativeLinearEquationSolverMethod::GaussSeidel);
            nativeEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-10"));
            return environment;
        }
    };
    
    /// Test configuration: native solver, SOR method, double arithmetic,
    /// termination precision 1e-10.
    class NativeDoubleSorEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            auto& nativeEnvironment = solverEnvironment.native();
            nativeEnvironment.setMethod(storm::solver::NativeLinearEquationSolverMethod::SOR);
            nativeEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-10"));
            return environment;
        }
    };
    
    /// Test configuration: native solver, Walker-Chae method, double arithmetic,
    /// termination precision 1e-8 and at most 500000 iterations.
    class NativeDoubleWalkerChaeEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            auto& nativeEnvironment = solverEnvironment.native();
            nativeEnvironment.setMethod(storm::solver::NativeLinearEquationSolverMethod::WalkerChae);
            nativeEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            nativeEnvironment.setMaximalNumberOfIterations(500000);
            return environment;
        }
    };
    
    /// Test configuration: native solver, rational search method, exact rational arithmetic.
    class NativeRationalRationalSearchEnvironment {
    public:
        using ValueType = storm::RationalNumber;
        static constexpr bool isExact = true;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Native);
            solverEnvironment.native().setMethod(storm::solver::NativeLinearEquationSolverMethod::RationalSearch);
            return environment;
        }
    };
    
    /// Test configuration: elimination-based solver with exact rational arithmetic.
    class EliminationRationalEnvironment {
    public:
        using ValueType = storm::RationalNumber;
        static constexpr bool isExact = true;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            environment.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Elimination);
            return environment;
        }
    };
    
    /// Test configuration: gmm++ solver, GMRES with ILU preconditioning, double arithmetic,
    /// precision 1e-8.
    class GmmGmresIluEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
            auto& gmmxxEnvironment = solverEnvironment.gmmxx();
            gmmxxEnvironment.setMethod(storm::solver::GmmxxLinearEquationSolverMethod::Gmres);
            gmmxxEnvironment.setPreconditioner(storm::solver::GmmxxLinearEquationSolverPreconditioner::Ilu);
            gmmxxEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };

    /// Test configuration: gmm++ solver, GMRES with diagonal preconditioning, double arithmetic,
    /// precision 1e-8.
    class GmmGmresDiagonalEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
            auto& gmmxxEnvironment = solverEnvironment.gmmxx();
            gmmxxEnvironment.setMethod(storm::solver::GmmxxLinearEquationSolverMethod::Gmres);
            gmmxxEnvironment.setPreconditioner(storm::solver::GmmxxLinearEquationSolverPreconditioner::Diagonal);
            gmmxxEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };

    /// Test configuration: gmm++ solver, GMRES without preconditioning, double arithmetic,
    /// precision 1e-8.
    class GmmGmresNoneEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
            auto& gmmxxEnvironment = solverEnvironment.gmmxx();
            gmmxxEnvironment.setMethod(storm::solver::GmmxxLinearEquationSolverMethod::Gmres);
            gmmxxEnvironment.setPreconditioner(storm::solver::GmmxxLinearEquationSolverPreconditioner::None);
            gmmxxEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };
    
    /// Test configuration: gmm++ solver, BiCGSTAB with ILU preconditioning, double arithmetic,
    /// precision 1e-8.
    class GmmBicgstabIluEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
            auto& gmmxxEnvironment = solverEnvironment.gmmxx();
            gmmxxEnvironment.setMethod(storm::solver::GmmxxLinearEquationSolverMethod::Bicgstab);
            gmmxxEnvironment.setPreconditioner(storm::solver::GmmxxLinearEquationSolverPreconditioner::Ilu);
            gmmxxEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };
    
    /// Test configuration: gmm++ solver, QMR with diagonal preconditioning, double arithmetic,
    /// precision 1e-8.
    class GmmQmrDiagonalEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Gmmxx);
            auto& gmmxxEnvironment = solverEnvironment.gmmxx();
            gmmxxEnvironment.setMethod(storm::solver::GmmxxLinearEquationSolverMethod::Qmr);
            gmmxxEnvironment.setPreconditioner(storm::solver::GmmxxLinearEquationSolverPreconditioner::Diagonal);
            gmmxxEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };
    
    /// Test configuration: Eigen solver, DGMRES with diagonal preconditioning, double arithmetic,
    /// precision 1e-8.
    class EigenDGmresDiagonalEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
            auto& eigenEnvironment = solverEnvironment.eigen();
            eigenEnvironment.setMethod(storm::solver::EigenLinearEquationSolverMethod::DGmres);
            eigenEnvironment.setPreconditioner(storm::solver::EigenLinearEquationSolverPreconditioner::Diagonal);
            eigenEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };
    
    /// Test configuration: Eigen solver, GMRES with ILU preconditioning, double arithmetic,
    /// precision 1e-8.
    class EigenGmresIluEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
            auto& eigenEnvironment = solverEnvironment.eigen();
            eigenEnvironment.setMethod(storm::solver::EigenLinearEquationSolverMethod::Gmres);
            eigenEnvironment.setPreconditioner(storm::solver::EigenLinearEquationSolverPreconditioner::Ilu);
            eigenEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };
    
    /// Test configuration: Eigen solver, BiCGSTAB without preconditioning, double arithmetic,
    /// precision 1e-8.
    class EigenBicgstabNoneEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
            auto& eigenEnvironment = solverEnvironment.eigen();
            eigenEnvironment.setMethod(storm::solver::EigenLinearEquationSolverMethod::Bicgstab);
            eigenEnvironment.setPreconditioner(storm::solver::EigenLinearEquationSolverPreconditioner::None);
            eigenEnvironment.setPrecision(storm::utility::convertNumber<storm::RationalNumber, std::string>("1e-8"));
            return environment;
        }
    };
    
    /// Test configuration: Eigen solver using sparse LU decomposition with double arithmetic
    /// (no precision is configured here).
    class EigenDoubleLUEnvironment {
    public:
        using ValueType = double;
        static constexpr bool isExact = false;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
            solverEnvironment.eigen().setMethod(storm::solver::EigenLinearEquationSolverMethod::SparseLU);
            return environment;
        }
    };
    
    /// Test configuration: Eigen solver using sparse LU decomposition with exact rational arithmetic.
    class EigenRationalLUEnvironment {
    public:
        using ValueType = storm::RationalNumber;
        static constexpr bool isExact = true;

        static storm::Environment createEnvironment() {
            storm::Environment environment;
            auto& solverEnvironment = environment.solver();
            solverEnvironment.setLinearEquationSolverType(storm::solver::EquationSolverType::Eigen);
            solverEnvironment.eigen().setMethod(storm::solver::EigenLinearEquationSolverMethod::SparseLU);
            return environment;
        }
    };
    
    /// Typed test fixture parameterized by one of the *Environment configuration classes above.
    ///
    /// TestType must provide a ValueType typedef, a static bool isExact and a static
    /// createEnvironment() returning a storm::Environment.
    template<typename TestType>
    class LinearEquationSolverTest : public ::testing::Test {
    public:
        using ValueType = typename TestType::ValueType;

        LinearEquationSolverTest() : testEnvironment(TestType::createEnvironment()) {
        }

        /// The solver environment created once per test from TestType.
        storm::Environment const& env() const {
            return testEnvironment;
        }

        /// Tolerance for result comparisons: 0 for exact value types, 1e-6 otherwise.
        ValueType precision() const {
            if (TestType::isExact) {
                return parseNumber("0");
            }
            return parseNumber("1e-6");
        }

        /// Converts a textual number (e.g. "1/5", "-0.01") into ValueType.
        ValueType parseNumber(std::string const& input) const {
            return storm::utility::convertNumber<ValueType>(input);
        }

    private:
        storm::Environment testEnvironment;
    };
  
    // Every solver configuration the typed tests below are instantiated with.
    // The order determines the numeric suffix of the generated test names.
    using TestingTypes = ::testing::Types<
            NativeDoublePowerEnvironment,
            NativeDoubleSoundPowerEnvironment,
            NativeDoubleJacobiEnvironment,
            NativeDoubleGaussSeidelEnvironment,
            NativeDoubleSorEnvironment,
            NativeDoubleWalkerChaeEnvironment,
            NativeRationalRationalSearchEnvironment,
            EliminationRationalEnvironment,
            GmmGmresIluEnvironment,
            GmmGmresDiagonalEnvironment,
            GmmGmresNoneEnvironment,
            GmmBicgstabIluEnvironment,
            GmmQmrDiagonalEnvironment,
            EigenDGmresDiagonalEnvironment,
            EigenGmresIluEnvironment,
            EigenBicgstabNoneEnvironment,
            EigenDoubleLUEnvironment,
            EigenRationalLUEnvironment
    >;

    TYPED_TEST_CASE(LinearEquationSolverTest, TestingTypes);
    
    
    /// Solves a fixed 3x3 system with the configured solver and compares against the known solution.
    ///
    /// The expected values (481/9, 457/9, 875/18) satisfy the fixed-point form x = A*x + b for the
    /// matrix A and vector b below. If the factory reports that it works on the EquationSystem
    /// problem format, A is converted via convertToEquationSystem() before the solver is created.
    TYPED_TEST(LinearEquationSolverTest, solveEquationSystem) {
        using ValueType = typename TestFixture::ValueType;

        // Constructing a builder on its own must not throw.
        ASSERT_NO_THROW(storm::storage::SparseMatrixBuilder<ValueType> probeBuilder);

        // Matrix entries as (row, column, value), in the row-major order the builder expects.
        struct MatrixEntry {
            unsigned row;
            unsigned column;
            char const* value;
        };
        MatrixEntry const entries[] = {
            {0, 0, "1/5"},  {0, 1, "2/5"},   {0, 2, "2/5"},
            {1, 0, "1/50"}, {1, 1, "48/50"}, {1, 2, "1/50"},
            {2, 0, "4/10"}, {2, 1, "3/10"},  {2, 2, "0"}
        };

        storm::storage::SparseMatrixBuilder<ValueType> builder;
        for (MatrixEntry const& entry : entries) {
            ASSERT_NO_THROW(builder.addNextValue(entry.row, entry.column, this->parseNumber(entry.value)));
        }

        storm::storage::SparseMatrix<ValueType> A;
        ASSERT_NO_THROW(A = builder.build());

        std::vector<ValueType> x(3);
        std::vector<ValueType> b{this->parseNumber("3"), this->parseNumber("-0.01"), this->parseNumber("12")};

        storm::solver::GeneralLinearEquationSolverFactory<ValueType> factory{};
        bool const wantsEquationSystem = factory.getEquationProblemFormat(this->env()) == storm::solver::LinearEquationSolverProblemFormat::EquationSystem;
        if (wantsEquationSystem) {
            A.convertToEquationSystem();
        }

        // Bound requirements are cleared because bounds are provided explicitly via setBounds below;
        // any other remaining requirement fails the test.
        auto requirements = factory.getRequirements(this->env());
        requirements.clearUpperBounds();
        requirements.clearLowerBounds();
        ASSERT_TRUE(requirements.empty());

        auto solver = factory.create(this->env(), A);
        solver->setBounds(this->parseNumber("-100"), this->parseNumber("100"));
        ASSERT_NO_THROW(solver->solveEquations(this->env(), x, b));

        EXPECT_NEAR(x[0], this->parseNumber("481/9"), this->precision());
        EXPECT_NEAR(x[1], this->parseNumber("457/9"), this->precision());
        EXPECT_NEAR(x[2], this->parseNumber("875/18"), this->precision());
    }
    
    /// Runs four repeated matrix-vector multiplications, starting from x = (0, 0, 0, 0, 1).
    ///
    /// In the 5x5 matrix below, rows 0-2 put weight 1/2 on the next column and 1/2 on column 4,
    /// and rows 3 and 4 put weight 1 on column 4. Starting from x = (0, 0, 0, 0, 1), the entries
    /// take the values 1/2, 3/4, 7/8 and 1 along the way; after four multiplications every entry
    /// equals 1. The test checks the whole result vector, not only x[0], so wrong values in the
    /// other entries are detected too.
    TYPED_TEST(LinearEquationSolverTest, MatrixVectorMultiplication) {
        typedef typename TestFixture::ValueType ValueType;
        // Constructing a builder on its own must not throw.
        ASSERT_NO_THROW(storm::storage::SparseMatrixBuilder<ValueType> builder);
        storm::storage::SparseMatrixBuilder<ValueType> builder;
        ASSERT_NO_THROW(builder.addNextValue(0, 1, this->parseNumber("0.5")));
        ASSERT_NO_THROW(builder.addNextValue(0, 4, this->parseNumber("0.5")));
        ASSERT_NO_THROW(builder.addNextValue(1, 2, this->parseNumber("0.5")));
        ASSERT_NO_THROW(builder.addNextValue(1, 4, this->parseNumber("0.5")));
        ASSERT_NO_THROW(builder.addNextValue(2, 3, this->parseNumber("0.5")));
        ASSERT_NO_THROW(builder.addNextValue(2, 4, this->parseNumber("0.5")));
        ASSERT_NO_THROW(builder.addNextValue(3, 4, this->parseNumber("1")));
        ASSERT_NO_THROW(builder.addNextValue(4, 4, this->parseNumber("1")));
        
        storm::storage::SparseMatrix<ValueType> A;
        ASSERT_NO_THROW(A = builder.build());
        
        std::vector<ValueType> x(5);
        x[4] = this->parseNumber("1");

        auto factory = storm::solver::GeneralLinearEquationSolverFactory<ValueType>();
        auto solver = factory.create(this->env(), A);
        ASSERT_NO_THROW(solver->repeatedMultiply(x, nullptr, 4));

        // After four steps every entry has reached 1 (see the derivation above), so verify all of
        // them; the index is streamed into the failure message to pinpoint a mismatch.
        for (decltype(x.size()) state = 0; state < x.size(); ++state) {
            EXPECT_NEAR(x[state], this->parseNumber("1"), this->precision()) << "at index " << state;
        }
    }
}