From f89236100b9f676e497209c7c0ff9583df94cdbf Mon Sep 17 00:00:00 2001 From: TimQu Date: Sat, 9 Dec 2017 15:21:46 +0100 Subject: [PATCH] Added topological linear equation solver --- src/storm/environment/SubEnvironment.cpp | 2 + .../environment/solver/SolverEnvironment.cpp | 19 +- .../environment/solver/SolverEnvironment.h | 4 + ...logicalLinearEquationSolverEnvironment.cpp | 38 ++ ...pologicalLinearEquationSolverEnvironment.h | 24 ++ src/storm/settings/modules/CoreSettings.cpp | 4 +- src/storm/solver/LinearEquationSolver.cpp | 6 +- src/storm/solver/SolverSelectionOptions.cpp | 2 + src/storm/solver/SolverSelectionOptions.h | 2 +- .../TopologicalLinearEquationSolver.cpp | 344 ++++++++++++++++++ .../solver/TopologicalLinearEquationSolver.h | 84 +++++ 11 files changed, 522 insertions(+), 7 deletions(-) create mode 100644 src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.cpp create mode 100644 src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h create mode 100644 src/storm/solver/TopologicalLinearEquationSolver.cpp create mode 100644 src/storm/solver/TopologicalLinearEquationSolver.h diff --git a/src/storm/environment/SubEnvironment.cpp b/src/storm/environment/SubEnvironment.cpp index 698300f03..6227548e5 100644 --- a/src/storm/environment/SubEnvironment.cpp +++ b/src/storm/environment/SubEnvironment.cpp @@ -5,6 +5,7 @@ #include "storm/environment/solver/NativeSolverEnvironment.h" #include "storm/environment/solver/MinMaxSolverEnvironment.h" #include "storm/environment/solver/GameSolverEnvironment.h" +#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h" namespace storm { @@ -40,6 +41,7 @@ namespace storm { template class SubEnvironment; template class SubEnvironment; template class SubEnvironment; + template class SubEnvironment; } diff --git a/src/storm/environment/solver/SolverEnvironment.cpp b/src/storm/environment/solver/SolverEnvironment.cpp index 8371b8158..6e0eb49c0 100644 --- a/src/storm/environment/solver/SolverEnvironment.cpp +++ b/src/storm/environment/solver/SolverEnvironment.cpp @@ -5,6 +5,7 @@ #include "storm/environment/solver/GmmxxSolverEnvironment.h" #include "storm/environment/solver/NativeSolverEnvironment.h" #include "storm/environment/solver/GameSolverEnvironment.h" +#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h" #include "storm/settings/SettingsManager.h" #include "storm/settings/modules/GeneralSettings.h" @@ -66,6 +67,14 @@ namespace storm { return gameSolverEnvironment.get(); } + TopologicalLinearEquationSolverEnvironment& SolverEnvironment::topological() { + return topologicalSolverEnvironment.get(); + } + + TopologicalLinearEquationSolverEnvironment const& SolverEnvironment::topological() const { + return topologicalSolverEnvironment.get(); + } + bool SolverEnvironment::isForceSoundness() const { return forceSoundness; } @@ -106,22 +115,24 @@ namespace storm { STORM_LOG_ASSERT(getLinearEquationSolverType() == storm::solver::EquationSolverType::Native || getLinearEquationSolverType() == storm::solver::EquationSolverType::Gmmxx || getLinearEquationSolverType() == storm::solver::EquationSolverType::Eigen || - getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination, + getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination || + getLinearEquationSolverType() == storm::solver::EquationSolverType::Topological, "The current solver type is not respected in this method."); native().setPrecision(value); gmmxx().setPrecision(value); eigen().setPrecision(value); - // Elimination solver does not have a precision + // Elimination and Topological solver do not have a precision } void SolverEnvironment::setLinearEquationSolverRelativeTerminationCriterion(bool value) { STORM_LOG_ASSERT(getLinearEquationSolverType() == storm::solver::EquationSolverType::Native || getLinearEquationSolverType() == storm::solver::EquationSolverType::Gmmxx || getLinearEquationSolverType() == storm::solver::EquationSolverType::Eigen || - getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination, + getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination || + getLinearEquationSolverType() == storm::solver::EquationSolverType::Topological, "The current solver type is not respected in this method."); native().setRelativeTerminationCriterion(value); - // Elimination, gmm and eigen solver do not have an option for relative termination criterion + // Elimination, gmm, eigen, and topological solver do not have an option for relative termination criterion } diff --git a/src/storm/environment/solver/SolverEnvironment.h b/src/storm/environment/solver/SolverEnvironment.h index ec591beb2..7f3260291 100644 --- a/src/storm/environment/solver/SolverEnvironment.h +++ b/src/storm/environment/solver/SolverEnvironment.h @@ -16,6 +16,7 @@ namespace storm { class NativeSolverEnvironment; class MinMaxSolverEnvironment; class GameSolverEnvironment; + class TopologicalLinearEquationSolverEnvironment; class SolverEnvironment { public: @@ -33,6 +34,8 @@ namespace storm { MinMaxSolverEnvironment const& minMax() const; GameSolverEnvironment& game(); GameSolverEnvironment const& game() const; + TopologicalLinearEquationSolverEnvironment& topological(); + TopologicalLinearEquationSolverEnvironment const& topological() const; bool isForceSoundness() const; void setForceSoundness(bool value); @@ -50,6 +53,7 @@ namespace storm { SubEnvironment gmmxxSolverEnvironment; SubEnvironment nativeSolverEnvironment; SubEnvironment gameSolverEnvironment; + SubEnvironment topologicalSolverEnvironment; SubEnvironment minMaxSolverEnvironment; storm::solver::EquationSolverType linearEquationSolverType; diff --git a/src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.cpp b/src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.cpp new file mode 100644 index 000000000..e49c0c15f --- /dev/null +++ b/src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.cpp @@ -0,0 +1,38 @@ +#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h" + +#include "storm/settings/SettingsManager.h" +#include "storm/settings/modules/GameSolverSettings.h" +#include "storm/utility/macros.h" + +#include "storm/exceptions/InvalidArgumentException.h" + +namespace storm { + + TopologicalLinearEquationSolverEnvironment::TopologicalLinearEquationSolverEnvironment() { + auto const& topologicalSettings = storm::settings::getModule(); + std::cout << "TODO: get actual settings in topo environment." << std::endl; + underlyingSolverType = storm::solver::EquationSolverType::Native; + underlyingSolverTypeSetFromDefault = true; + } + + TopologicalLinearEquationSolverEnvironment::~TopologicalLinearEquationSolverEnvironment() { + // Intentionally left empty + } + + storm::solver::EquationSolverType const& TopologicalLinearEquationSolverEnvironment::getUnderlyingSolverType() const { + return underlyingSolverType; + } + + bool const& TopologicalLinearEquationSolverEnvironment::isUnderlyingSolverTypeSetFromDefault() const { + return underlyingSolverTypeSetFromDefault; + } + + void TopologicalLinearEquationSolverEnvironment::setUnderlyingSolverType(storm::solver::EquationSolverType value) { + STORM_LOG_THROW(value != storm::solver::EquationSolverType::Topological, storm::exceptions::InvalidArgumentException, "Can not use the topological solver as underlying solver of the topological solver."); + underlyingSolverTypeSetFromDefault = false; + underlyingSolverType = value; + } + + + +} diff --git a/src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h b/src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h new file mode 100644 index 000000000..b79ed66a9 --- /dev/null +++ b/src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h @@ -0,0 +1,24 @@ +#pragma once + +#include "storm/environment/solver/SolverEnvironment.h" + +#include "storm/solver/SolverSelectionOptions.h" + +namespace storm { + + class TopologicalLinearEquationSolverEnvironment { + public: + + TopologicalLinearEquationSolverEnvironment(); + ~TopologicalLinearEquationSolverEnvironment(); + + storm::solver::EquationSolverType const& getUnderlyingSolverType() const; + bool const& isUnderlyingSolverTypeSetFromDefault() const; + void setUnderlyingSolverType(storm::solver::EquationSolverType value); + + private: + storm::solver::EquationSolverType underlyingSolverType; + bool underlyingSolverTypeSetFromDefault; + }; +} + diff --git a/src/storm/settings/modules/CoreSettings.cpp b/src/storm/settings/modules/CoreSettings.cpp index b01aee3f7..9bcd4d77c 100644 --- a/src/storm/settings/modules/CoreSettings.cpp +++ b/src/storm/settings/modules/CoreSettings.cpp @@ -43,7 +43,7 @@ namespace storm { this->addOption(storm::settings::OptionBuilder(moduleName, engineOptionName, false, "Sets which engine is used for model building and model checking.").setShortName(engineOptionShortName) .addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the engine to use.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(engines)).setDefaultValueString("sparse").build()).build()); - std::vector linearEquationSolver = {"gmm++", "native", "eigen", "elimination"}; + std::vector linearEquationSolver = {"gmm++", "native", "eigen", "elimination", "topological"}; this->addOption(storm::settings::OptionBuilder(moduleName, eqSolverOptionName, false, "Sets which solver is preferred for solving systems of linear equations.") .addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the solver to prefer.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(linearEquationSolver)).setDefaultValueString("gmm++").build()).build()); @@ -86,6 +86,8 @@ namespace storm { return storm::solver::EquationSolverType::Eigen; } else if (equationSolverName == "elimination") { return storm::solver::EquationSolverType::Elimination; + } else if (equationSolverName == "topological") { + return storm::solver::EquationSolverType::Topological; } STORM_LOG_THROW(false, storm::exceptions::IllegalArgumentValueException, "Unknown equation solver '" << equationSolverName << "'."); } diff --git a/src/storm/solver/LinearEquationSolver.cpp b/src/storm/solver/LinearEquationSolver.cpp index caff69756..193f2cc09 100644 --- a/src/storm/solver/LinearEquationSolver.cpp +++ b/src/storm/solver/LinearEquationSolver.cpp @@ -7,6 +7,7 @@ #include "storm/solver/NativeLinearEquationSolver.h" #include "storm/solver/EigenLinearEquationSolver.h" #include "storm/solver/EliminationLinearEquationSolver.h" +#include "storm/solver/TopologicalLinearEquationSolver.h" #include "storm/utility/vector.h" @@ -179,6 +180,7 @@ namespace storm { case EquationSolverType::Native: return std::make_unique>(); case EquationSolverType::Eigen: return std::make_unique>(); case EquationSolverType::Elimination: return std::make_unique>(); + case EquationSolverType::Topological: return std::make_unique>(); default: STORM_LOG_THROW(false, storm::exceptions::InvalidEnvironmentException, "Unknown solver type."); return nullptr; @@ -198,6 +200,7 @@ namespace storm { switch (type) { case EquationSolverType::Eigen: return std::make_unique>(); case EquationSolverType::Elimination: return std::make_unique>(); + case EquationSolverType::Topological: return std::make_unique>(); default: STORM_LOG_THROW(false, storm::exceptions::InvalidEnvironmentException, "Unknown solver type."); return nullptr; @@ -209,7 +212,7 @@ namespace storm { EquationSolverType type = env.solver().getLinearEquationSolverType(); // Adjust the solver type if none was specified and we want sound computations - if (env.solver().isForceSoundness() && task != LinearEquationSolverTask::Multiply && type != EquationSolverType::Native && type != EquationSolverType::Eigen && type != EquationSolverType::Elimination) { + if (env.solver().isForceSoundness() && task != LinearEquationSolverTask::Multiply && type != EquationSolverType::Native && type != EquationSolverType::Eigen && type != EquationSolverType::Elimination && type != EquationSolverType::Topological) { if (env.solver().isLinearEquationSolverTypeSetFromDefaultValue()) { type = EquationSolverType::Native; STORM_LOG_INFO("Selecting '" + toString(type) + "' as the linear equation solver to guarantee sound results. If you want to override this, please explicitly specify a different solver."); @@ -223,6 +226,7 @@ namespace storm { case EquationSolverType::Native: return std::make_unique>(); case EquationSolverType::Eigen: return std::make_unique>(); case EquationSolverType::Elimination: return std::make_unique>(); + case EquationSolverType::Topological: return std::make_unique>(); default: STORM_LOG_THROW(false, storm::exceptions::InvalidEnvironmentException, "Unknown solver type."); return nullptr; diff --git a/src/storm/solver/SolverSelectionOptions.cpp b/src/storm/solver/SolverSelectionOptions.cpp index 0393e6912..7b8e19c65 100644 --- a/src/storm/solver/SolverSelectionOptions.cpp +++ b/src/storm/solver/SolverSelectionOptions.cpp @@ -50,6 +50,8 @@ namespace storm { return "Eigen"; case EquationSolverType::Elimination: return "Elimination"; + case EquationSolverType::Topological: + return "Topological"; } return "invalid"; } diff --git a/src/storm/solver/SolverSelectionOptions.h b/src/storm/solver/SolverSelectionOptions.h index e5c7a55eb..446e173bd 100644 --- a/src/storm/solver/SolverSelectionOptions.h +++ b/src/storm/solver/SolverSelectionOptions.h @@ -11,7 +11,7 @@ namespace storm { ExtendEnumsWithSelectionField(LraMethod, LinearProgramming, ValueIteration) ExtendEnumsWithSelectionField(LpSolverType, Gurobi, Glpk, Z3) - ExtendEnumsWithSelectionField(EquationSolverType, Native, Gmmxx, Eigen, Elimination) + ExtendEnumsWithSelectionField(EquationSolverType, Native, Gmmxx, Eigen, Elimination, Topological) ExtendEnumsWithSelectionField(SmtSolverType, Z3, Mathsat) ExtendEnumsWithSelectionField(NativeLinearEquationSolverMethod, Jacobi, GaussSeidel, SOR, WalkerChae, Power, RationalSearch, QuickPower) diff --git a/src/storm/solver/TopologicalLinearEquationSolver.cpp b/src/storm/solver/TopologicalLinearEquationSolver.cpp new file mode 100644 index 000000000..c7ebee971 --- /dev/null +++ b/src/storm/solver/TopologicalLinearEquationSolver.cpp @@ -0,0 +1,344 @@ +#include "storm/solver/TopologicalLinearEquationSolver.h" + +#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h" + +#include "storm/utility/constants.h" +#include "storm/utility/vector.h" +#include "storm/exceptions/InvalidStateException.h" +#include "storm/exceptions/InvalidEnvironmentException.h" +#include "storm/exceptions/UnexpectedException.h" + +namespace storm { + namespace solver { + + template + TopologicalLinearEquationSolver::TopologicalLinearEquationSolver() : localA(nullptr), A(nullptr) { + // Intentionally left empty. + } + + template + TopologicalLinearEquationSolver::TopologicalLinearEquationSolver(storm::storage::SparseMatrix const& A) : localA(nullptr), A(nullptr) { + this->setMatrix(A); + } + + template + TopologicalLinearEquationSolver::TopologicalLinearEquationSolver(storm::storage::SparseMatrix&& A) : localA(nullptr), A(nullptr) { + this->setMatrix(std::move(A)); + } + + template + void TopologicalLinearEquationSolver::setMatrix(storm::storage::SparseMatrix const& A) { + localA.reset(); + this->A = &A; + clearCache(); + } + + template + void TopologicalLinearEquationSolver::setMatrix(storm::storage::SparseMatrix&& A) { + localA = std::make_unique>(std::move(A)); + this->A = localA.get(); + clearCache(); + } + + template + storm::Environment TopologicalLinearEquationSolver::getEnvironmentForUnderlyingSolver(storm::Environment const& env) const { + storm::Environment subEnv(env); + subEnv.solver().setLinearEquationSolverType(env.solver().topological().getUnderlyingSolverType()); + return subEnv; + } + + template + bool TopologicalLinearEquationSolver::internalSolveEquations(Environment const& env, std::vector& x, std::vector const& b) const { + + if (!this->sortedSccDecomposition) { + STORM_LOG_TRACE("Creating SCC decomposition."); + createSortedSccDecomposition(); + } + + storm::Environment sccSolverEnvironment = getEnvironmentForUnderlyingSolver(env); + + // Handle the case where there is just one large SCC + if (this->sortedSccDecomposition->size() == 1) { + return solveFullyConnectedEquationSystem(sccSolverEnvironment, x, b); + } + + storm::storage::BitVector sccAsBitVector(x.size(), false); + bool returnValue = true; + for (auto const& scc : *this->sortedSccDecomposition) { + if (scc.isTrivial()) { + returnValue = returnValue && solveTrivialScc(*scc.begin(), x, b); + } else { + sccAsBitVector.clear(); + for (auto const& state : scc) { + sccAsBitVector.set(state, true); + } + returnValue = returnValue && solveScc(sccSolverEnvironment, sccAsBitVector, x, b); + } + } + + if (!this->isCachingEnabled()) { + clearCache(); + } + + return returnValue; + } + + template + void TopologicalLinearEquationSolver::createSortedSccDecomposition() const { + // Obtain the scc decomposition + auto sccDecomposition = storm::storage::StronglyConnectedComponentDecomposition(*this->A); + + // Get a mapping from matrix row to the corresponding scc + STORM_LOG_THROW(sccDecomposition.size() < std::numeric_limits::max(), storm::exceptions::UnexpectedException, "The number of SCCs is too large."); + std::vector sccIndices(this->A->getRowCount(), std::numeric_limits::max()); + uint32_t sccIndex = 0; + for (auto const& scc : sccDecomposition) { + for (auto const& row : scc) { + sccIndices[row] = sccIndex; + } + ++sccIndex; + } + + // Prepare the resulting set of sorted sccs + this->sortedSccDecomposition = std::make_unique>(); + std::vector& sortedSCCs = *this->sortedSccDecomposition; + sortedSCCs.reserve(sccDecomposition.size()); + + // Find a topological sort via DFS. + storm::storage::BitVector unsortedSCCs(sccDecomposition.size(), true); + std::vector sccStack; + uint32_t const token = std::numeric_limits::max(); + std::set successorSCCs; + + for (uint32_t firstUnsortedScc = 0; firstUnsortedScc < unsortedSCCs.size(); firstUnsortedScc = unsortedSCCs.getNextSetIndex(firstUnsortedScc + 1)) { + + sccStack.push_back(firstUnsortedScc); + while (!sccStack.empty()) { + auto const& currentSccIndex = sccStack.back(); + if (currentSccIndex != token) { + // Check whether the SCC is still unprocessed + if (unsortedSCCs.get(currentSccIndex)) { + // Explore the successors of the scc. + storm::storage::StronglyConnectedComponent const& currentScc = sccDecomposition.getBlock(currentSccIndex); + // We first push a token on the stack in order to recognize later when all successors of this SCC have been explored already. + sccStack.push_back(token); + // Now add all successors that are not already sorted. + // Successors should only be added once, so we first prepare a set of them and add them afterwards. + successorSCCs.clear(); + for (auto const& row : currentScc) { + for (auto const& entry : this->A->getRow(row)) { + auto const& successorSCC = sccIndices[entry.getColumn()]; + if (successorSCC != currentSccIndex && unsortedSCCs.get(successorSCC)) { + successorSCCs.insert(successorSCC); + } + } + } + sccStack.insert(sccStack.end(), successorSCCs.begin(), successorSCCs.end()); + + } + } else { + // all successors of the current scc have already been explored. + sccStack.pop_back(); // pop the token + sortedSCCs.push_back(std::move(sccDecomposition.getBlock(sccStack.back()))); + unsortedSCCs.set(sccStack.back(), false); + sccStack.pop_back(); // pop the current scc index + } + } + } + } + + template + bool TopologicalLinearEquationSolver::solveTrivialScc(uint64_t const& sccState, std::vector& globalX, std::vector const& globalB) const { + ValueType& xi = globalX[sccState]; + xi = globalB[sccState]; + bool hasDiagonalEntry = false; + ValueType denominator; + for (auto const& entry : this->A->getRow(sccState)) { + if (entry.getColumn() == sccState) { + STORM_LOG_ASSERT(!storm::utility::isOne(entry.getValue()), "Diagonal entry of fix point system has value 1."); + hasDiagonalEntry = true; + denominator = storm::utility::one() - entry.getValue(); + } else { + xi += entry.getValue() * globalX[entry.getColumn()]; + } + } + + if (hasDiagonalEntry) { + xi /= denominator; + } + return true; + } + + template + bool TopologicalLinearEquationSolver::solveFullyConnectedEquationSystem(storm::Environment const& sccSolverEnvironment, std::vector& x, std::vector const& b) const { + if (!this->sccSolver) { + this->sccSolver = GeneralLinearEquationSolverFactory().create(sccSolverEnvironment, LinearEquationSolverTask::SolveEquations); + this->sccSolver->setCachingEnabled(true); + this->sccSolver->setBoundsFromOtherSolver(*this); + if (this->sccSolver->getEquationProblemFormat(sccSolverEnvironment) == LinearEquationSolverProblemFormat::EquationSystem) { + // Convert the matrix to an equation system. Note that we need to insert diagonal entries. + storm::storage::SparseMatrix eqSysA(*this->A, true); + eqSysA.convertToEquationSystem(); + this->sccSolver->setMatrix(std::move(eqSysA)); + } else { + this->sccSolver->setMatrix(*this->A); + } + } + return this->sccSolver->solveEquations(sccSolverEnvironment, x, b); + } + + template + bool TopologicalLinearEquationSolver::solveScc(storm::Environment const& sccSolverEnvironment, storm::storage::BitVector const& scc, std::vector& globalX, std::vector const& globalB) const { + + // Set up the SCC solver + if (!this->sccSolver) { + this->sccSolver = GeneralLinearEquationSolverFactory().create(sccSolverEnvironment, LinearEquationSolverTask::SolveEquations); + this->sccSolver->setCachingEnabled(true); + } + + // Matrix + bool asEquationSystem = this->sccSolver->getEquationProblemFormat(sccSolverEnvironment) == LinearEquationSolverProblemFormat::EquationSystem; + storm::storage::SparseMatrix sccA = this->A->getSubmatrix(true, scc, scc, asEquationSystem); + if (asEquationSystem) { + sccA.convertToEquationSystem(); + } + this->sccSolver->setMatrix(std::move(sccA)); + + // x Vector + auto sccX = storm::utility::vector::filterVector(globalX, scc); + + // b Vector + std::vector sccB; + sccB.reserve(scc.getNumberOfSetBits()); + for (auto const& row : scc) { + ValueType bi = globalB[row]; + for (auto const& entry : this->A->getRow(row)) { + if (!scc.get(entry.getColumn())) { + bi += entry.getValue() * globalX[entry.getColumn()]; + } + } + sccB.push_back(std::move(bi)); + } + + // lower/upper bounds + if (this->hasLowerBound(storm::solver::AbstractEquationSolver::BoundType::Global)) { + this->sccSolver->setLowerBound(this->getLowerBound()); + } else if (this->hasLowerBound(storm::solver::AbstractEquationSolver::BoundType::Local)) { + this->sccSolver->setLowerBounds(storm::utility::vector::filterVector(this->getLowerBounds(), scc)); + } + if (this->hasUpperBound(storm::solver::AbstractEquationSolver::BoundType::Global)) { + this->sccSolver->setUpperBound(this->getUpperBound()); + } else if (this->hasUpperBound(storm::solver::AbstractEquationSolver::BoundType::Local)) { + this->sccSolver->setUpperBounds(storm::utility::vector::filterVector(this->getUpperBounds(), scc)); + } + + return this->sccSolver->solveEquations(sccSolverEnvironment, sccX, sccB); + } + + + template + void TopologicalLinearEquationSolver::multiply(std::vector& x, std::vector const* b, std::vector& result) const { + if (&x != &result) { + multiplier.multAdd(*A, x, b, result); + } else { + // If the two vectors are aliases, we need to create a temporary. + if (!this->cachedRowVector) { + this->cachedRowVector = std::make_unique>(getMatrixRowCount()); + } + + multiplier.multAdd(*A, x, b, *this->cachedRowVector); + result.swap(*this->cachedRowVector); + + if (!this->isCachingEnabled()) { + clearCache(); + } + } + } + + template + void TopologicalLinearEquationSolver::multiplyAndReduce(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector& result, std::vector* choices) const { + if (&x != &result) { + multiplier.multAddReduce(dir, rowGroupIndices, *A, x, b, result, choices); + } else { + // If the two vectors are aliases, we need to create a temporary. + if (!this->cachedRowVector) { + this->cachedRowVector = std::make_unique>(getMatrixRowCount()); + } + + multiplier.multAddReduce(dir, rowGroupIndices, *A, x, b, *this->cachedRowVector, choices); + result.swap(*this->cachedRowVector); + + if (!this->isCachingEnabled()) { + clearCache(); + } + } + } + + template + bool TopologicalLinearEquationSolver::supportsGaussSeidelMultiplication() const { + return true; + } + + template + void TopologicalLinearEquationSolver::multiplyGaussSeidel(std::vector& x, std::vector const* b) const { + STORM_LOG_ASSERT(this->A->getRowCount() == this->A->getColumnCount(), "This function is only applicable for square matrices."); + multiplier.multAddGaussSeidelBackward(*A, x, b); + } + + template + void TopologicalLinearEquationSolver::multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices) const { + multiplier.multAddReduceGaussSeidelBackward(dir, rowGroupIndices, *A, x, b, choices); + } + + template + LinearEquationSolverProblemFormat TopologicalLinearEquationSolver::getEquationProblemFormat(Environment const& env) const { + return LinearEquationSolverProblemFormat::FixedPointSystem; + } + + template + LinearEquationSolverRequirements TopologicalLinearEquationSolver::getRequirements(Environment const& env, LinearEquationSolverTask const& task) const { + // Return the requirements of the underlying solver + return GeneralLinearEquationSolverFactory().getRequirements(getEnvironmentForUnderlyingSolver(env), task); + } + + template + void TopologicalLinearEquationSolver::clearCache() const { + sortedSccDecomposition.reset(); + sccSolver.reset(); + LinearEquationSolver::clearCache(); + } + + template + uint64_t TopologicalLinearEquationSolver::getMatrixRowCount() const { + return this->A->getRowCount(); + } + + template + uint64_t TopologicalLinearEquationSolver::getMatrixColumnCount() const { + return this->A->getColumnCount(); + } + + template + std::unique_ptr> TopologicalLinearEquationSolverFactory::create(Environment const& env, LinearEquationSolverTask const& task) const { + return std::make_unique>(); + } + + template + std::unique_ptr> TopologicalLinearEquationSolverFactory::clone() const { + return std::make_unique>(*this); + } + + // Explicitly instantiate the linear equation solver. + template class TopologicalLinearEquationSolver; + template class TopologicalLinearEquationSolverFactory; + +#ifdef STORM_HAVE_CARL + template class TopologicalLinearEquationSolver; + template class TopologicalLinearEquationSolverFactory; + + template class TopologicalLinearEquationSolver; + template class TopologicalLinearEquationSolverFactory; + +#endif + } +} diff --git a/src/storm/solver/TopologicalLinearEquationSolver.h b/src/storm/solver/TopologicalLinearEquationSolver.h new file mode 100644 index 000000000..a3d25c1c1 --- /dev/null +++ b/src/storm/solver/TopologicalLinearEquationSolver.h @@ -0,0 +1,84 @@ +#pragma once + +#include "storm/solver/LinearEquationSolver.h" + +#include "storm/solver/SolverSelectionOptions.h" +#include "storm/solver/NativeMultiplier.h" +#include "storm/storage/StronglyConnectedComponentDecomposition.h" + +namespace storm { + + class Environment; + + namespace solver { + + template + class TopologicalLinearEquationSolver : public LinearEquationSolver { + public: + TopologicalLinearEquationSolver(); + TopologicalLinearEquationSolver(storm::storage::SparseMatrix const& A); + TopologicalLinearEquationSolver(storm::storage::SparseMatrix&& A); + + virtual void setMatrix(storm::storage::SparseMatrix const& A) override; + virtual void setMatrix(storm::storage::SparseMatrix&& A) override; + + virtual void multiply(std::vector& x, std::vector const* b, std::vector& result) const override; + virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector& result, std::vector* choices = nullptr) const override; + virtual bool supportsGaussSeidelMultiplication() const override; + virtual void multiplyGaussSeidel(std::vector& x, std::vector const* b) const override; + virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices = nullptr) const override; + + virtual LinearEquationSolverProblemFormat getEquationProblemFormat(storm::Environment const& env) const override; + virtual LinearEquationSolverRequirements getRequirements(Environment const& env, LinearEquationSolverTask const& task = LinearEquationSolverTask::Unspecified) const override; + + virtual void clearCache() const override; + + protected: + virtual bool internalSolveEquations(storm::Environment const& env, std::vector& x, std::vector const& b) const override; + + private: + + virtual uint64_t getMatrixRowCount() const override; + virtual uint64_t getMatrixColumnCount() const override; + + storm::Environment getEnvironmentForUnderlyingSolver(storm::Environment const& env) const; + + // Creates an SCC decomposition and sorts the SCCs according to a topological sort. + void createSortedSccDecomposition() const; + + // Solves the SCC with the given index + // ... for the case that the SCC is trivial + bool solveTrivialScc(uint64_t const& sccState, std::vector& globalX, std::vector const& globalB) const; + // ... for the case that there is just one large SCC + bool solveFullyConnectedEquationSystem(storm::Environment const& sccSolverEnvironment, std::vector& x, std::vector const& b) const; + // ... for the remaining cases (1 < scc.size() < x.size()) + bool solveScc(storm::Environment const& sccSolverEnvironment, storm::storage::BitVector const& scc, std::vector& globalX, std::vector const& globalB) const; + + // If the solver takes posession of the matrix, we store the moved matrix in this member, so it gets deleted + // when the solver is destructed. + std::unique_ptr> localA; + + // A pointer to the original sparse matrix given to this solver. If the solver takes posession of the matrix + // the pointer refers to localA. + storm::storage::SparseMatrix const* A; + + // An object to dispatch all multiplication operations. + NativeMultiplier multiplier; + + // cached auxiliary data + mutable std::unique_ptr> sortedSccDecomposition; + mutable std::unique_ptr> sccSolver; + }; + + template + class TopologicalLinearEquationSolverFactory : public LinearEquationSolverFactory { + public: + using LinearEquationSolverFactory::create; + + virtual std::unique_ptr> create(Environment const& env, LinearEquationSolverTask const& task = LinearEquationSolverTask::Unspecified) const override; + + virtual std::unique_ptr> clone() const override; + + }; + } +}