#include "GmmxxLinearEquationSolver.h"

#include <cmath>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

#include "storm/adapters/GmmxxAdapter.h"
#include "storm/solver/GmmxxMultiplier.h"
#include "storm/environment/solver/GmmxxSolverEnvironment.h"
#include "storm/utility/vector.h"
#include "storm/utility/constants.h"
#include "storm/utility/gmm.h"

namespace storm {
    namespace solver {

        /// Creates a solver without a matrix. A matrix must be provided via setMatrix() before the solver
        /// is used: the solve/multiply members dereference the stored gmm++ matrix unconditionally.
        template<typename ValueType>
        GmmxxLinearEquationSolver<ValueType>::GmmxxLinearEquationSolver() {
            // Intentionally left empty.
        }

        /// Creates a solver for the given matrix (converted to gmm++ format via setMatrix()).
        template<typename ValueType>
        GmmxxLinearEquationSolver<ValueType>::GmmxxLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A) {
            this->setMatrix(A);
        }

        /// Creates a solver for the given matrix (forwarded to the rvalue overload of setMatrix()).
        template<typename ValueType>
        GmmxxLinearEquationSolver<ValueType>::GmmxxLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A) {
            this->setMatrix(std::move(A));
        }

        /// Converts A into the gmm++ sparse matrix used by this solver and drops every cached object
        /// (including preconditioners built for a previous matrix).
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::setMatrix(storm::storage::SparseMatrix<ValueType> const& A) {
            gmmxxA = storm::adapters::GmmxxAdapter<ValueType>::toGmmxxSparseMatrix(A);
            clearCache();
        }

        /// Rvalue overload of setMatrix().
        /// NOTE(review): A is handed to the adapter as an lvalue, so it is converted exactly like the
        /// const& overload and is not consumed; confirm whether the adapter offers a moving conversion.
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::setMatrix(storm::storage::SparseMatrix<ValueType>&& A) {
            gmmxxA = storm::adapters::GmmxxAdapter<ValueType>::toGmmxxSparseMatrix(A);
            clearCache();
        }

        /// Returns the gmm++ solution method configured in env. If env demands sound computations an
        /// error is logged (this solver cannot guarantee soundness) but execution continues regardless.
        template<typename ValueType>
        GmmxxLinearEquationSolverMethod GmmxxLinearEquationSolver<ValueType>::getMethod(Environment const& env) const {
            STORM_LOG_ERROR_COND(!env.solver().isForceSoundness(), "This linear equation solver does not support sound computations. Using unsound methods now...");
            return env.solver().gmmxx().getMethod();
        }

        /// Solves the equation system A*x = b with gmm++'s BiCGSTAB, QMR or GMRES solver, as selected in env.
        ///
        /// @param env Provides method, preconditioner, precision, iteration bound and GMRES restart threshold.
        /// @param x   Initial guess on entry; on exit the computed solution, clipped to this solver's bounds.
        /// @param b   Right-hand side of the system.
        /// @return true iff the iterative method reported convergence; false if it did not converge or the
        ///         configured method is not one of the supported ones (both cases are logged).
        template<typename ValueType>
        bool GmmxxLinearEquationSolver<ValueType>::internalSolveEquations(Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
            auto method = getMethod(env);
            auto const& gmmxxEnv = env.solver().gmmxx();
            auto preconditioner = gmmxxEnv.getPreconditioner();
            STORM_LOG_INFO("Solving linear equation system (" << x.size() << " rows) with Gmmxx linear equation solver with method '" << toString(method) << "' and preconditioner '" << toString(preconditioner) << "'.");

            bool const isSupportedMethod = method == GmmxxLinearEquationSolverMethod::Bicgstab || method == GmmxxLinearEquationSolverMethod::Qmr || method == GmmxxLinearEquationSolverMethod::Gmres;
            if (!isSupportedMethod) {
                STORM_LOG_ERROR("Selected method is not available");
                return false;
            }

            // Make sure that the requested preconditioner is available. Both kinds are cached and only rebuilt
            // after clearCache() (which setMatrix() calls), so a cached diagonal preconditioner is reused just
            // like the ILU one instead of being recomputed on every solve.
            if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Ilu && !iluPreconditioner) {
                iluPreconditioner = std::make_unique<gmm::ilu_precond<gmm::csr_matrix<ValueType>>>(*gmmxxA);
            } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Diagonal && !diagonalPreconditioner) {
                diagonalPreconditioner = std::make_unique<gmm::diagonal_precond<gmm::csr_matrix<ValueType>>>(*gmmxxA);
            }

            // Prepare an iteration object that determines the accuracy and the maximum number of iterations.
            gmm::iteration iter(storm::utility::convertNumber<ValueType>(gmmxxEnv.getPrecision()), 0, gmmxxEnv.getMaximalNumberOfIterations());

            // Invoke gmm with the corresponding settings.
            if (method == GmmxxLinearEquationSolverMethod::Bicgstab) {
                if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Ilu) {
                    gmm::bicgstab(*gmmxxA, x, b, *iluPreconditioner, iter);
                } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Diagonal) {
                    gmm::bicgstab(*gmmxxA, x, b, *diagonalPreconditioner, iter);
                } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::None) {
                    gmm::bicgstab(*gmmxxA, x, b, gmm::identity_matrix(), iter);
                }
            } else if (method == GmmxxLinearEquationSolverMethod::Qmr) {
                if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Ilu) {
                    gmm::qmr(*gmmxxA, x, b, *iluPreconditioner, iter);
                } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Diagonal) {
                    gmm::qmr(*gmmxxA, x, b, *diagonalPreconditioner, iter);
                } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::None) {
                    gmm::qmr(*gmmxxA, x, b, gmm::identity_matrix(), iter);
                }
            } else if (method == GmmxxLinearEquationSolverMethod::Gmres) {
                if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Ilu) {
                    gmm::gmres(*gmmxxA, x, b, *iluPreconditioner, gmmxxEnv.getRestartThreshold(), iter);
                } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::Diagonal) {
                    gmm::gmres(*gmmxxA, x, b, *diagonalPreconditioner, gmmxxEnv.getRestartThreshold(), iter);
                } else if (preconditioner == GmmxxLinearEquationSolverPreconditioner::None) {
                    gmm::gmres(*gmmxxA, x, b, gmm::identity_matrix(), gmmxxEnv.getRestartThreshold(), iter);
                }
            }

            if (!this->isCachingEnabled()) {
                clearCache();
            }

            // Make sure that all results conform to the bounds.
            storm::utility::vector::clip(x, this->lowerBound, this->upperBound);

            // Check if the solver converged and issue a warning otherwise.
            if (iter.converged()) {
                STORM_LOG_INFO("Iterative solver converged after " << iter.get_iteration() << " iterations.");
                return true;
            }
            STORM_LOG_WARN("Iterative solver did not converge.");
            return false;
        }

        /// Delegates to the multiplier's multAdd on the gmm++ matrix, writing into result.
        /// NOTE(review): presumably result = A*x (+ b when b is non-null); confirm against GmmxxMultiplier.
        /// Clears the cache afterwards when caching is disabled.
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::multiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const {
            multiplier.multAdd(*gmmxxA, x, b, result);

            if (!this->isCachingEnabled()) {
                clearCache();
            }
        }

        /// Delegates to the multiplier's multAddReduce: a multiplication followed by a reduction over the
        /// row groups given by rowGroupIndices in direction dir. If choices is non-null the multiplier may
        /// record the selected rows there (semantics owned by GmmxxMultiplier).
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::multiplyAndReduce(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices) const {
            multiplier.multAddReduce(dir, rowGroupIndices, *gmmxxA, x, b, result, choices);
        }

        /// This solver always supports Gauss-Seidel style (in-place) multiplication.
        template<typename ValueType>
        bool GmmxxLinearEquationSolver<ValueType>::supportsGaussSeidelMultiplication() const {
            return true;
        }

        /// Delegates to the multiplier's multAddGaussSeidelBackward; x is the only output and is updated in place.
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::multiplyGaussSeidel(std::vector<ValueType>& x, std::vector<ValueType> const* b) const {
            multiplier.multAddGaussSeidelBackward(*gmmxxA, x, b);
        }

        /// In-place (Gauss-Seidel style) variant of multiplyAndReduce(); delegates to multAddReduceGaussSeidel.
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices) const {
            multiplier.multAddReduceGaussSeidel(dir, rowGroupIndices, *gmmxxA, x, b, choices);
        }

        /// This solver always expects the problem as an equation system, independent of env.
        template<typename ValueType>
        LinearEquationSolverProblemFormat GmmxxLinearEquationSolver<ValueType>::getEquationProblemFormat(Environment const& env) const {
            return LinearEquationSolverProblemFormat::EquationSystem;
        }

        /// Drops the cached ILU and diagonal preconditioners, then clears the base-class cache.
        template<typename ValueType>
        void GmmxxLinearEquationSolver<ValueType>::clearCache() const {
            iluPreconditioner.reset();
            diagonalPreconditioner.reset();
            LinearEquationSolver<ValueType>::clearCache();
        }

        /// Number of rows of the stored gmm++ matrix (requires a matrix to have been set).
        template<typename ValueType>
        uint64_t GmmxxLinearEquationSolver<ValueType>::getMatrixRowCount() const {
            return gmmxxA->nr;
        }

        /// Number of columns of the stored gmm++ matrix (requires a matrix to have been set).
        template<typename ValueType>
        uint64_t GmmxxLinearEquationSolver<ValueType>::getMatrixColumnCount() const {
            return gmmxxA->nc;
        }

        /// Creates a fresh solver without a matrix; env and task are not consulted.
        template<typename ValueType>
        std::unique_ptr<LinearEquationSolver<ValueType>> GmmxxLinearEquationSolverFactory<ValueType>::create(Environment const& env, LinearEquationSolverTask const& task) const {
            return std::make_unique<storm::solver::GmmxxLinearEquationSolver<ValueType>>();
        }

        /// Returns a copy of this factory.
        template<typename ValueType>
        std::unique_ptr<LinearEquationSolverFactory<ValueType>> GmmxxLinearEquationSolverFactory<ValueType>::clone() const {
            return std::make_unique<GmmxxLinearEquationSolverFactory<ValueType>>(*this);
        }

        // Explicitly instantiate the solver.
        template class GmmxxLinearEquationSolver<double>;
        template class GmmxxLinearEquationSolverFactory<double>;

    }
}