#include "storm/solver/StandardMinMaxLinearEquationSolver.h"

#include "storm/solver/IterativeMinMaxLinearEquationSolver.h"
#include "storm/solver/GmmxxLinearEquationSolver.h"
#include "storm/solver/EigenLinearEquationSolver.h"
#include "storm/solver/NativeLinearEquationSolver.h"
#include "storm/solver/EliminationLinearEquationSolver.h"

#include "storm/environment/solver/MinMaxSolverEnvironment.h"

#include "storm/utility/vector.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/InvalidSettingsException.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/NotImplementedException.h"

namespace storm {
    namespace solver {
        
        /*!
         * Creates a solver that has no matrix yet. A matrix must be supplied via setMatrix before
         * repeatedMultiply is called.
         *
         * @param linearEquationSolverFactory Factory used to create the underlying linear equation solver.
         */
        template<typename ValueType>
        StandardMinMaxLinearEquationSolver<ValueType>::StandardMinMaxLinearEquationSolver(std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : linearEquationSolverFactory(std::move(linearEquationSolverFactory)), A(nullptr) {
            // Intentionally left empty.
        }
        
        /*!
         * Creates a solver that refers to (but does not own) the given matrix. The caller must keep the
         * matrix alive for as long as the solver uses it.
         *
         * @param A The matrix of the min/max equation system.
         * @param linearEquationSolverFactory Factory used to create the underlying linear equation solver.
         */
        template<typename ValueType>
        StandardMinMaxLinearEquationSolver<ValueType>::StandardMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : linearEquationSolverFactory(std::move(linearEquationSolverFactory)), localA(nullptr), A(&A) {
            // Intentionally left empty.
        }
        
        /*!
         * Creates a solver that takes ownership of the given matrix (stored in localA).
         *
         * @param A The matrix of the min/max equation system; it is moved into the solver.
         * @param linearEquationSolverFactory Factory used to create the underlying linear equation solver.
         */
        template<typename ValueType>
        StandardMinMaxLinearEquationSolver<ValueType>::StandardMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : linearEquationSolverFactory(std::move(linearEquationSolverFactory)), localA(std::make_unique<storm::storage::SparseMatrix<ValueType>>(std::move(A))), A(localA.get()) {
            // Intentionally left empty.
        }
        
        /*!
         * Makes the solver refer to (but not own) the given matrix. Any previously owned matrix is released
         * and all cached data derived from the previous matrix is discarded.
         */
        template<typename ValueType>
        void StandardMinMaxLinearEquationSolver<ValueType>::setMatrix(storm::storage::SparseMatrix<ValueType> const& matrix) {
            // Cached helpers (e.g. linEqSolverA) were created for the old matrix and may still reference it,
            // so drop them before the old matrix is released and replaced.
            this->clearCache();
            this->localA = nullptr;
            this->A = &matrix;
        }
        
        /*!
         * Makes the solver take ownership of the given matrix. All cached data derived from the previous
         * matrix is discarded.
         */
        template<typename ValueType>
        void StandardMinMaxLinearEquationSolver<ValueType>::setMatrix(storm::storage::SparseMatrix<ValueType>&& matrix) {
            // See the const& overload: cached helpers must not outlive or be reused for the old matrix.
            this->clearCache();
            this->localA = std::make_unique<storm::storage::SparseMatrix<ValueType>>(std::move(matrix));
            this->A = this->localA.get();
        }
        
        /*!
         * Applies linEqSolverA->multiplyAndReduce n times, using the row grouping of A and the direction dir.
         * After each step the result is swapped into x, so x holds the result of the last step on return.
         * The linear equation solver and the row-group-sized auxiliary vector are created lazily. They are
         * dropped again afterwards unless caching is enabled.
         *
         * @param env The environment passed on to the linear equation solver factory.
         * @param dir The optimization direction used for the reduction.
         * @param x The input vector; overwritten with the final result.
         * @param b Optional vector forwarded to multiplyAndReduce (may be nullptr).
         * @param n The number of steps.
         * @throws InvalidStateException if no matrix has been set.
         */
        template<typename ValueType>
        void StandardMinMaxLinearEquationSolver<ValueType>::repeatedMultiply(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const* b, uint_fast64_t n) const {
            // The factory-only constructor leaves A null; fail loudly instead of dereferencing it.
            STORM_LOG_THROW(this->A != nullptr, storm::exceptions::InvalidStateException, "Cannot perform repeated multiplication: no matrix has been set.");
            
            if (!linEqSolverA) {
                linEqSolverA = linearEquationSolverFactory->create(env, *A, LinearEquationSolverTask::Multiply);
                linEqSolverA->setCachingEnabled(true);
            }
            
            if (!auxiliaryRowGroupVector) {
                auxiliaryRowGroupVector = std::make_unique<std::vector<ValueType>>(this->A->getRowGroupCount());
            }
            
            this->startMeasureProgress();
            for (uint64_t i = 0; i < n; ++i) {
                linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, b, *auxiliaryRowGroupVector);
                // Swap instead of copy: x receives the new result, the auxiliary vector becomes scratch space.
                std::swap(x, *auxiliaryRowGroupVector);
                
                // Potentially show progress.
                this->showProgressIterative(i, n);
            }
            
            if (!this->isCachingEnabled()) {
                clearCache();
            }
        }
        
        /*!
         * Releases all cached data of this solver (the linear equation solver and the auxiliary vectors),
         * then clears the cache of the base class.
         */
        template<typename ValueType>
        void StandardMinMaxLinearEquationSolver<ValueType>::clearCache() const {
            linEqSolverA.reset();
            auxiliaryRowVector.reset();
            // Allocated by repeatedMultiply with A's row group count; must be dropped as well so that
            // non-caching runs release it and a later matrix change cannot reuse a wrongly sized buffer.
            auxiliaryRowGroupVector.reset();
            MinMaxLinearEquationSolver<ValueType>::clearCache();
        }
        
        /*!
         * Creates a factory that uses a GeneralLinearEquationSolverFactory for the underlying linear solvers.
         */
        template<typename ValueType>
        StandardMinMaxLinearEquationSolverFactory<ValueType>::StandardMinMaxLinearEquationSolverFactory() : MinMaxLinearEquationSolverFactory<ValueType>(), linearEquationSolverFactory(std::make_unique<GeneralLinearEquationSolverFactory<ValueType>>()) {
            // Intentionally left empty.
        }
        
        /*!
         * Creates a factory that uses the given linear equation solver factory.
         */
        template<typename ValueType>
        StandardMinMaxLinearEquationSolverFactory<ValueType>::StandardMinMaxLinearEquationSolverFactory(std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : MinMaxLinearEquationSolverFactory<ValueType>(), linearEquationSolverFactory(std::move(linearEquationSolverFactory)) {
            // Intentionally left empty.
        }
        
        /*!
         * Creates a factory whose underlying linear equation solver factory is selected by solverType.
         *
         * @throws InvalidSettingsException if solverType is not one of Gmmxx, Eigen, Native or Elimination.
         */
        template<typename ValueType>
        StandardMinMaxLinearEquationSolverFactory<ValueType>::StandardMinMaxLinearEquationSolverFactory(EquationSolverType const& solverType) : MinMaxLinearEquationSolverFactory<ValueType>() {
            switch (solverType) {
                case EquationSolverType::Gmmxx: linearEquationSolverFactory = std::make_unique<GmmxxLinearEquationSolverFactory<ValueType>>(); break;
                case EquationSolverType::Eigen: linearEquationSolverFactory = std::make_unique<EigenLinearEquationSolverFactory<ValueType>>(); break;
                case EquationSolverType::Native: linearEquationSolverFactory = std::make_unique<NativeLinearEquationSolverFactory<ValueType>>(); break;
                case EquationSolverType::Elimination: linearEquationSolverFactory = std::make_unique<EliminationLinearEquationSolverFactory<ValueType>>(); break;
                default:
                    // Without this, linearEquationSolverFactory would stay null and create() would dereference it.
                    STORM_LOG_THROW(false, storm::exceptions::InvalidSettingsException, "Unsupported equation solver for this data type.");
            }
        }
        
        /*!
         * Exact rational arithmetic is only available via the Eigen and Elimination solvers.
         *
         * @throws InvalidSettingsException for any other solver type.
         */
        template<>
        StandardMinMaxLinearEquationSolverFactory<storm::RationalNumber>::StandardMinMaxLinearEquationSolverFactory(EquationSolverType const& solverType) : MinMaxLinearEquationSolverFactory<storm::RationalNumber>() {
            switch (solverType) {
                case EquationSolverType::Eigen: linearEquationSolverFactory = std::make_unique<EigenLinearEquationSolverFactory<storm::RationalNumber>>(); break;
                case EquationSolverType::Elimination: linearEquationSolverFactory = std::make_unique<EliminationLinearEquationSolverFactory<storm::RationalNumber>>(); break;
                default:
                    STORM_LOG_THROW(false, storm::exceptions::InvalidSettingsException, "Unsupported equation solver for this data type.");
            }
        }
        
        /*!
         * Creates an IterativeMinMaxLinearEquationSolver for the min/max method configured in env. The solver
         * gets a clone of this factory's linear equation solver factory.
         *
         * @throws InvalidSettingsException if the method is not value iteration, policy iteration or rational search.
         */
        template<typename ValueType>
        std::unique_ptr<MinMaxLinearEquationSolver<ValueType>> StandardMinMaxLinearEquationSolverFactory<ValueType>::create(Environment const& env) const {
            std::unique_ptr<MinMaxLinearEquationSolver<ValueType>> result;
            auto method = env.solver().minMax().getMethod();
            if (method == MinMaxMethod::ValueIteration || method == MinMaxMethod::PolicyIteration || method == MinMaxMethod::RationalSearch) {
                result = std::make_unique<IterativeMinMaxLinearEquationSolver<ValueType>>(this->linearEquationSolverFactory->clone());
            } else {
                STORM_LOG_THROW(false, storm::exceptions::InvalidSettingsException, "The selected min max method is not supported by this solver.");
            }
            result->setRequirementsChecked(this->isRequirementsCheckedSet());
            return result;
        }
        
        /// Factory preconfigured to use the gmm++ linear equation solver.
        template<typename ValueType>
        GmmxxMinMaxLinearEquationSolverFactory<ValueType>::GmmxxMinMaxLinearEquationSolverFactory() : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Gmmxx) {
            // Intentionally left empty.
        }
        
        /// Factory preconfigured to use the Eigen linear equation solver.
        template<typename ValueType>
        EigenMinMaxLinearEquationSolverFactory<ValueType>::EigenMinMaxLinearEquationSolverFactory() : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Eigen) {
            // Intentionally left empty.
        }
        
        /// Factory preconfigured to use the native linear equation solver.
        template<typename ValueType>
        NativeMinMaxLinearEquationSolverFactory<ValueType>::NativeMinMaxLinearEquationSolverFactory() : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Native) {
            // Intentionally left empty.
        }
        
        /// Factory preconfigured to use the elimination-based linear equation solver.
        template<typename ValueType>
        EliminationMinMaxLinearEquationSolverFactory<ValueType>::EliminationMinMaxLinearEquationSolverFactory() : StandardMinMaxLinearEquationSolverFactory<ValueType>(EquationSolverType::Elimination) {
            // Intentionally left empty.
        }
        
        template class StandardMinMaxLinearEquationSolver<double>;
        template class StandardMinMaxLinearEquationSolverFactory<double>;
        template class GmmxxMinMaxLinearEquationSolverFactory<double>;
        template class EigenMinMaxLinearEquationSolverFactory<double>;
        template class NativeMinMaxLinearEquationSolverFactory<double>;
        template class EliminationMinMaxLinearEquationSolverFactory<double>;
        
#ifdef STORM_HAVE_CARL
        template class StandardMinMaxLinearEquationSolver<storm::RationalNumber>;
        template class StandardMinMaxLinearEquationSolverFactory<storm::RationalNumber>;
        template class EigenMinMaxLinearEquationSolverFactory<storm::RationalNumber>;
        template class EliminationMinMaxLinearEquationSolverFactory<storm::RationalNumber>;
#endif
    }
}