Browse Source

Added topological linear equation solver

tempestpy_adaptions
TimQu 7 years ago
parent
commit
f89236100b
  1. 2
      src/storm/environment/SubEnvironment.cpp
  2. 19
      src/storm/environment/solver/SolverEnvironment.cpp
  3. 4
      src/storm/environment/solver/SolverEnvironment.h
  4. 38
      src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.cpp
  5. 24
      src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h
  6. 4
      src/storm/settings/modules/CoreSettings.cpp
  7. 6
      src/storm/solver/LinearEquationSolver.cpp
  8. 2
      src/storm/solver/SolverSelectionOptions.cpp
  9. 2
      src/storm/solver/SolverSelectionOptions.h
  10. 344
      src/storm/solver/TopologicalLinearEquationSolver.cpp
  11. 84
      src/storm/solver/TopologicalLinearEquationSolver.h

2
src/storm/environment/SubEnvironment.cpp

@ -5,6 +5,7 @@
#include "storm/environment/solver/NativeSolverEnvironment.h"
#include "storm/environment/solver/MinMaxSolverEnvironment.h"
#include "storm/environment/solver/GameSolverEnvironment.h"
#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h"
namespace storm {
@ -40,6 +41,7 @@ namespace storm {
template class SubEnvironment<NativeSolverEnvironment>;
template class SubEnvironment<MinMaxSolverEnvironment>;
template class SubEnvironment<GameSolverEnvironment>;
template class SubEnvironment<TopologicalLinearEquationSolverEnvironment>;
}

19
src/storm/environment/solver/SolverEnvironment.cpp

@ -5,6 +5,7 @@
#include "storm/environment/solver/GmmxxSolverEnvironment.h"
#include "storm/environment/solver/NativeSolverEnvironment.h"
#include "storm/environment/solver/GameSolverEnvironment.h"
#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h"
#include "storm/settings/SettingsManager.h"
#include "storm/settings/modules/GeneralSettings.h"
@ -66,6 +67,14 @@ namespace storm {
return gameSolverEnvironment.get();
}
TopologicalLinearEquationSolverEnvironment& SolverEnvironment::topological() {
return topologicalSolverEnvironment.get();
}
TopologicalLinearEquationSolverEnvironment const& SolverEnvironment::topological() const {
return topologicalSolverEnvironment.get();
}
bool SolverEnvironment::isForceSoundness() const {
return forceSoundness;
}
@ -106,22 +115,24 @@ namespace storm {
STORM_LOG_ASSERT(getLinearEquationSolverType() == storm::solver::EquationSolverType::Native ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Gmmxx ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Eigen ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination,
getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Topological,
"The current solver type is not respected in this method.");
native().setPrecision(value);
gmmxx().setPrecision(value);
eigen().setPrecision(value);
// Elimination solver does not have a precision
// Elimination and Topological solver do not have a precision
}
void SolverEnvironment::setLinearEquationSolverRelativeTerminationCriterion(bool value) {
STORM_LOG_ASSERT(getLinearEquationSolverType() == storm::solver::EquationSolverType::Native ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Gmmxx ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Eigen ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination,
getLinearEquationSolverType() == storm::solver::EquationSolverType::Elimination ||
getLinearEquationSolverType() == storm::solver::EquationSolverType::Topological,
"The current solver type is not respected in this method.");
native().setRelativeTerminationCriterion(value);
// Elimination, gmm and eigen solver do not have an option for relative termination criterion
// Elimination, gmm, eigen, and topological solver do not have an option for relative termination criterion
}

4
src/storm/environment/solver/SolverEnvironment.h

@ -16,6 +16,7 @@ namespace storm {
class NativeSolverEnvironment;
class MinMaxSolverEnvironment;
class GameSolverEnvironment;
class TopologicalLinearEquationSolverEnvironment;
class SolverEnvironment {
public:
@ -33,6 +34,8 @@ namespace storm {
MinMaxSolverEnvironment const& minMax() const;
GameSolverEnvironment& game();
GameSolverEnvironment const& game() const;
TopologicalLinearEquationSolverEnvironment& topological();
TopologicalLinearEquationSolverEnvironment const& topological() const;
bool isForceSoundness() const;
void setForceSoundness(bool value);
@ -50,6 +53,7 @@ namespace storm {
SubEnvironment<GmmxxSolverEnvironment> gmmxxSolverEnvironment;
SubEnvironment<NativeSolverEnvironment> nativeSolverEnvironment;
SubEnvironment<GameSolverEnvironment> gameSolverEnvironment;
SubEnvironment<TopologicalLinearEquationSolverEnvironment> topologicalSolverEnvironment;
SubEnvironment<MinMaxSolverEnvironment> minMaxSolverEnvironment;
storm::solver::EquationSolverType linearEquationSolverType;

38
src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.cpp

@ -0,0 +1,38 @@
#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h"
#include "storm/settings/SettingsManager.h"
#include "storm/settings/modules/GameSolverSettings.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/InvalidArgumentException.h"
namespace storm {
TopologicalLinearEquationSolverEnvironment::TopologicalLinearEquationSolverEnvironment() {
auto const& topologicalSettings = storm::settings::getModule<storm::settings::modules::GameSolverSettings>();
std::cout << "TODO: get actual settings in topo environment." << std::endl;
underlyingSolverType = storm::solver::EquationSolverType::Native;
underlyingSolverTypeSetFromDefault = true;
}
TopologicalLinearEquationSolverEnvironment::~TopologicalLinearEquationSolverEnvironment() {
// Intentionally left empty
}
storm::solver::EquationSolverType const& TopologicalLinearEquationSolverEnvironment::getUnderlyingSolverType() const {
return underlyingSolverType;
}
bool const& TopologicalLinearEquationSolverEnvironment::isUnderlyingSolverTypeSetFromDefault() const {
return underlyingSolverTypeSetFromDefault;
}
void TopologicalLinearEquationSolverEnvironment::setUnderlyingSolverType(storm::solver::EquationSolverType value) {
STORM_LOG_THROW(value != storm::solver::EquationSolverType::Topological, storm::exceptions::InvalidArgumentException, "Can not use the topological solver as underlying solver of the topological solver.");
underlyingSolverTypeSetFromDefault = false;
underlyingSolverType = value;
}
}

24
src/storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h

@ -0,0 +1,24 @@
#pragma once
#include "storm/environment/solver/SolverEnvironment.h"
#include "storm/solver/SolverSelectionOptions.h"
namespace storm {
class TopologicalLinearEquationSolverEnvironment {
public:
TopologicalLinearEquationSolverEnvironment();
~TopologicalLinearEquationSolverEnvironment();
storm::solver::EquationSolverType const& getUnderlyingSolverType() const;
bool const& isUnderlyingSolverTypeSetFromDefault() const;
void setUnderlyingSolverType(storm::solver::EquationSolverType value);
private:
storm::solver::EquationSolverType underlyingSolverType;
bool underlyingSolverTypeSetFromDefault;
};
}

4
src/storm/settings/modules/CoreSettings.cpp

@ -43,7 +43,7 @@ namespace storm {
this->addOption(storm::settings::OptionBuilder(moduleName, engineOptionName, false, "Sets which engine is used for model building and model checking.").setShortName(engineOptionShortName)
.addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the engine to use.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(engines)).setDefaultValueString("sparse").build()).build());
std::vector<std::string> linearEquationSolver = {"gmm++", "native", "eigen", "elimination"};
std::vector<std::string> linearEquationSolver = {"gmm++", "native", "eigen", "elimination", "topological"};
this->addOption(storm::settings::OptionBuilder(moduleName, eqSolverOptionName, false, "Sets which solver is preferred for solving systems of linear equations.")
.addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the solver to prefer.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(linearEquationSolver)).setDefaultValueString("gmm++").build()).build());
@ -86,6 +86,8 @@ namespace storm {
return storm::solver::EquationSolverType::Eigen;
} else if (equationSolverName == "elimination") {
return storm::solver::EquationSolverType::Elimination;
} else if (equationSolverName == "topological") {
return storm::solver::EquationSolverType::Topological;
}
STORM_LOG_THROW(false, storm::exceptions::IllegalArgumentValueException, "Unknown equation solver '" << equationSolverName << "'.");
}

6
src/storm/solver/LinearEquationSolver.cpp

@ -7,6 +7,7 @@
#include "storm/solver/NativeLinearEquationSolver.h"
#include "storm/solver/EigenLinearEquationSolver.h"
#include "storm/solver/EliminationLinearEquationSolver.h"
#include "storm/solver/TopologicalLinearEquationSolver.h"
#include "storm/utility/vector.h"
@ -179,6 +180,7 @@ namespace storm {
case EquationSolverType::Native: return std::make_unique<NativeLinearEquationSolver<storm::RationalNumber>>();
case EquationSolverType::Eigen: return std::make_unique<EigenLinearEquationSolver<storm::RationalNumber>>();
case EquationSolverType::Elimination: return std::make_unique<EliminationLinearEquationSolver<storm::RationalNumber>>();
case EquationSolverType::Topological: return std::make_unique<TopologicalLinearEquationSolver<storm::RationalNumber>>();
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidEnvironmentException, "Unknown solver type.");
return nullptr;
@ -198,6 +200,7 @@ namespace storm {
switch (type) {
case EquationSolverType::Eigen: return std::make_unique<EigenLinearEquationSolver<storm::RationalFunction>>();
case EquationSolverType::Elimination: return std::make_unique<EliminationLinearEquationSolver<storm::RationalFunction>>();
case EquationSolverType::Topological: return std::make_unique<TopologicalLinearEquationSolver<storm::RationalFunction>>();
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidEnvironmentException, "Unknown solver type.");
return nullptr;
@ -209,7 +212,7 @@ namespace storm {
EquationSolverType type = env.solver().getLinearEquationSolverType();
// Adjust the solver type if none was specified and we want sound computations
if (env.solver().isForceSoundness() && task != LinearEquationSolverTask::Multiply && type != EquationSolverType::Native && type != EquationSolverType::Eigen && type != EquationSolverType::Elimination) {
if (env.solver().isForceSoundness() && task != LinearEquationSolverTask::Multiply && type != EquationSolverType::Native && type != EquationSolverType::Eigen && type != EquationSolverType::Elimination && type != EquationSolverType::Topological) {
if (env.solver().isLinearEquationSolverTypeSetFromDefaultValue()) {
type = EquationSolverType::Native;
STORM_LOG_INFO("Selecting '" + toString(type) + "' as the linear equation solver to guarantee sound results. If you want to override this, please explicitly specify a different solver.");
@ -223,6 +226,7 @@ namespace storm {
case EquationSolverType::Native: return std::make_unique<NativeLinearEquationSolver<ValueType>>();
case EquationSolverType::Eigen: return std::make_unique<EigenLinearEquationSolver<ValueType>>();
case EquationSolverType::Elimination: return std::make_unique<EliminationLinearEquationSolver<ValueType>>();
case EquationSolverType::Topological: return std::make_unique<TopologicalLinearEquationSolver<ValueType>>();
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidEnvironmentException, "Unknown solver type.");
return nullptr;

2
src/storm/solver/SolverSelectionOptions.cpp

@ -50,6 +50,8 @@ namespace storm {
return "Eigen";
case EquationSolverType::Elimination:
return "Elimination";
case EquationSolverType::Topological:
return "Topological";
}
return "invalid";
}

2
src/storm/solver/SolverSelectionOptions.h

@ -11,7 +11,7 @@ namespace storm {
ExtendEnumsWithSelectionField(LraMethod, LinearProgramming, ValueIteration)
ExtendEnumsWithSelectionField(LpSolverType, Gurobi, Glpk, Z3)
ExtendEnumsWithSelectionField(EquationSolverType, Native, Gmmxx, Eigen, Elimination)
ExtendEnumsWithSelectionField(EquationSolverType, Native, Gmmxx, Eigen, Elimination, Topological)
ExtendEnumsWithSelectionField(SmtSolverType, Z3, Mathsat)
ExtendEnumsWithSelectionField(NativeLinearEquationSolverMethod, Jacobi, GaussSeidel, SOR, WalkerChae, Power, RationalSearch, QuickPower)

344
src/storm/solver/TopologicalLinearEquationSolver.cpp

@ -0,0 +1,344 @@
#include "storm/solver/TopologicalLinearEquationSolver.h"
#include "storm/environment/solver/TopologicalLinearEquationSolverEnvironment.h"
#include "storm/utility/constants.h"
#include "storm/utility/vector.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/InvalidEnvironmentException.h"
#include "storm/exceptions/UnexpectedException.h"
namespace storm {
namespace solver {
template<typename ValueType>
TopologicalLinearEquationSolver<ValueType>::TopologicalLinearEquationSolver() : localA(nullptr), A(nullptr) {
// Intentionally left empty.
}
template<typename ValueType>
TopologicalLinearEquationSolver<ValueType>::TopologicalLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A) : localA(nullptr), A(nullptr) {
this->setMatrix(A);
}
template<typename ValueType>
TopologicalLinearEquationSolver<ValueType>::TopologicalLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A) : localA(nullptr), A(nullptr) {
this->setMatrix(std::move(A));
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::setMatrix(storm::storage::SparseMatrix<ValueType> const& A) {
localA.reset();
this->A = &A;
clearCache();
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::setMatrix(storm::storage::SparseMatrix<ValueType>&& A) {
localA = std::make_unique<storm::storage::SparseMatrix<ValueType>>(std::move(A));
this->A = localA.get();
clearCache();
}
template<typename ValueType>
storm::Environment TopologicalLinearEquationSolver<ValueType>::getEnvironmentForUnderlyingSolver(storm::Environment const& env) const {
storm::Environment subEnv(env);
subEnv.solver().setLinearEquationSolverType(env.solver().topological().getUnderlyingSolverType());
return subEnv;
}
template<typename ValueType>
bool TopologicalLinearEquationSolver<ValueType>::internalSolveEquations(Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
if (!this->sortedSccDecomposition) {
STORM_LOG_TRACE("Creating SCC decomposition.");
createSortedSccDecomposition();
}
storm::Environment sccSolverEnvironment = getEnvironmentForUnderlyingSolver(env);
// Handle the case where there is just one large SCC
if (this->sortedSccDecomposition->size() == 1) {
return solveFullyConnectedEquationSystem(sccSolverEnvironment, x, b);
}
storm::storage::BitVector sccAsBitVector(x.size(), false);
bool returnValue = true;
for (auto const& scc : *this->sortedSccDecomposition) {
if (scc.isTrivial()) {
returnValue = returnValue && solveTrivialScc(*scc.begin(), x, b);
} else {
sccAsBitVector.clear();
for (auto const& state : scc) {
sccAsBitVector.set(state, true);
}
returnValue = returnValue && solveScc(sccSolverEnvironment, sccAsBitVector, x, b);
}
}
if (!this->isCachingEnabled()) {
clearCache();
}
return returnValue;
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::createSortedSccDecomposition() const {
// Obtain the scc decomposition
auto sccDecomposition = storm::storage::StronglyConnectedComponentDecomposition<ValueType>(*this->A);
// Get a mapping from matrix row to the corresponding scc
STORM_LOG_THROW(sccDecomposition.size() < std::numeric_limits<uint32_t>::max(), storm::exceptions::UnexpectedException, "The number of SCCs is too large.");
std::vector<uint32_t> sccIndices(this->A->getRowCount(), std::numeric_limits<uint32_t>::max());
uint32_t sccIndex = 0;
for (auto const& scc : sccDecomposition) {
for (auto const& row : scc) {
sccIndices[row] = sccIndex;
}
++sccIndex;
}
// Prepare the resulting set of sorted sccs
this->sortedSccDecomposition = std::make_unique<std::vector<storm::storage::StronglyConnectedComponent>>();
std::vector<storm::storage::StronglyConnectedComponent>& sortedSCCs = *this->sortedSccDecomposition;
sortedSCCs.reserve(sccDecomposition.size());
// Find a topological sort via DFS.
storm::storage::BitVector unsortedSCCs(sccDecomposition.size(), true);
std::vector<uint32_t> sccStack;
uint32_t const token = std::numeric_limits<uint32_t>::max();
std::set<uint64_t> successorSCCs;
for (uint32_t firstUnsortedScc = 0; firstUnsortedScc < unsortedSCCs.size(); firstUnsortedScc = unsortedSCCs.getNextSetIndex(firstUnsortedScc + 1)) {
sccStack.push_back(firstUnsortedScc);
while (!sccStack.empty()) {
auto const& currentSccIndex = sccStack.back();
if (currentSccIndex != token) {
// Check whether the SCC is still unprocessed
if (unsortedSCCs.get(currentSccIndex)) {
// Explore the successors of the scc.
storm::storage::StronglyConnectedComponent const& currentScc = sccDecomposition.getBlock(currentSccIndex);
// We first push a token on the stack in order to recognize later when all successors of this SCC have been explored already.
sccStack.push_back(token);
// Now add all successors that are not already sorted.
// Successors should only be added once, so we first prepare a set of them and add them afterwards.
successorSCCs.clear();
for (auto const& row : currentScc) {
for (auto const& entry : this->A->getRow(row)) {
auto const& successorSCC = sccIndices[entry.getColumn()];
if (successorSCC != currentSccIndex && unsortedSCCs.get(successorSCC)) {
successorSCCs.insert(successorSCC);
}
}
}
sccStack.insert(sccStack.end(), successorSCCs.begin(), successorSCCs.end());
}
} else {
// all successors of the current scc have already been explored.
sccStack.pop_back(); // pop the token
sortedSCCs.push_back(std::move(sccDecomposition.getBlock(sccStack.back())));
unsortedSCCs.set(sccStack.back(), false);
sccStack.pop_back(); // pop the current scc index
}
}
}
}
template<typename ValueType>
bool TopologicalLinearEquationSolver<ValueType>::solveTrivialScc(uint64_t const& sccState, std::vector<ValueType>& globalX, std::vector<ValueType> const& globalB) const {
ValueType& xi = globalX[sccState];
xi = globalB[sccState];
bool hasDiagonalEntry = false;
ValueType denominator;
for (auto const& entry : this->A->getRow(sccState)) {
if (entry.getColumn() == sccState) {
STORM_LOG_ASSERT(!storm::utility::isOne(entry.getValue()), "Diagonal entry of fix point system has value 1.");
hasDiagonalEntry = true;
denominator = storm::utility::one<ValueType>() - entry.getValue();
} else {
xi += entry.getValue() * globalX[entry.getColumn()];
}
}
if (hasDiagonalEntry) {
xi /= denominator;
}
return true;
}
template<typename ValueType>
bool TopologicalLinearEquationSolver<ValueType>::solveFullyConnectedEquationSystem(storm::Environment const& sccSolverEnvironment, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
if (!this->sccSolver) {
this->sccSolver = GeneralLinearEquationSolverFactory<ValueType>().create(sccSolverEnvironment, LinearEquationSolverTask::SolveEquations);
this->sccSolver->setCachingEnabled(true);
this->sccSolver->setBoundsFromOtherSolver(*this);
if (this->sccSolver->getEquationProblemFormat(sccSolverEnvironment) == LinearEquationSolverProblemFormat::EquationSystem) {
// Convert the matrix to an equation system. Note that we need to insert diagonal entries.
storm::storage::SparseMatrix<ValueType> eqSysA(*this->A, true);
eqSysA.convertToEquationSystem();
this->sccSolver->setMatrix(std::move(eqSysA));
} else {
this->sccSolver->setMatrix(*this->A);
}
}
return this->sccSolver->solveEquations(sccSolverEnvironment, x, b);
}
template<typename ValueType>
bool TopologicalLinearEquationSolver<ValueType>::solveScc(storm::Environment const& sccSolverEnvironment, storm::storage::BitVector const& scc, std::vector<ValueType>& globalX, std::vector<ValueType> const& globalB) const {
// Set up the SCC solver
if (!this->sccSolver) {
this->sccSolver = GeneralLinearEquationSolverFactory<ValueType>().create(sccSolverEnvironment, LinearEquationSolverTask::SolveEquations);
this->sccSolver->setCachingEnabled(true);
}
// Matrix
bool asEquationSystem = this->sccSolver->getEquationProblemFormat(sccSolverEnvironment) == LinearEquationSolverProblemFormat::EquationSystem;
storm::storage::SparseMatrix<ValueType> sccA = this->A->getSubmatrix(true, scc, scc, asEquationSystem);
if (asEquationSystem) {
sccA.convertToEquationSystem();
}
this->sccSolver->setMatrix(std::move(sccA));
// x Vector
auto sccX = storm::utility::vector::filterVector(globalX, scc);
// b Vector
std::vector<ValueType> sccB;
sccB.reserve(scc.getNumberOfSetBits());
for (auto const& row : scc) {
ValueType bi = globalB[row];
for (auto const& entry : this->A->getRow(row)) {
if (!scc.get(entry.getColumn())) {
bi += entry.getValue() * globalX[entry.getColumn()];
}
}
sccB.push_back(std::move(bi));
}
// lower/upper bounds
if (this->hasLowerBound(storm::solver::AbstractEquationSolver<ValueType>::BoundType::Global)) {
this->sccSolver->setLowerBound(this->getLowerBound());
} else if (this->hasLowerBound(storm::solver::AbstractEquationSolver<ValueType>::BoundType::Local)) {
this->sccSolver->setLowerBounds(storm::utility::vector::filterVector(this->getLowerBounds(), scc));
}
if (this->hasUpperBound(storm::solver::AbstractEquationSolver<ValueType>::BoundType::Global)) {
this->sccSolver->setUpperBound(this->getUpperBound());
} else if (this->hasUpperBound(storm::solver::AbstractEquationSolver<ValueType>::BoundType::Local)) {
this->sccSolver->setUpperBounds(storm::utility::vector::filterVector(this->getUpperBounds(), scc));
}
return this->sccSolver->solveEquations(sccSolverEnvironment, sccX, sccB);
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::multiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const {
if (&x != &result) {
multiplier.multAdd(*A, x, b, result);
} else {
// If the two vectors are aliases, we need to create a temporary.
if (!this->cachedRowVector) {
this->cachedRowVector = std::make_unique<std::vector<ValueType>>(getMatrixRowCount());
}
multiplier.multAdd(*A, x, b, *this->cachedRowVector);
result.swap(*this->cachedRowVector);
if (!this->isCachingEnabled()) {
clearCache();
}
}
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::multiplyAndReduce(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices) const {
if (&x != &result) {
multiplier.multAddReduce(dir, rowGroupIndices, *A, x, b, result, choices);
} else {
// If the two vectors are aliases, we need to create a temporary.
if (!this->cachedRowVector) {
this->cachedRowVector = std::make_unique<std::vector<ValueType>>(getMatrixRowCount());
}
multiplier.multAddReduce(dir, rowGroupIndices, *A, x, b, *this->cachedRowVector, choices);
result.swap(*this->cachedRowVector);
if (!this->isCachingEnabled()) {
clearCache();
}
}
}
template<typename ValueType>
bool TopologicalLinearEquationSolver<ValueType>::supportsGaussSeidelMultiplication() const {
return true;
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::multiplyGaussSeidel(std::vector<ValueType>& x, std::vector<ValueType> const* b) const {
STORM_LOG_ASSERT(this->A->getRowCount() == this->A->getColumnCount(), "This function is only applicable for square matrices.");
multiplier.multAddGaussSeidelBackward(*A, x, b);
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices) const {
multiplier.multAddReduceGaussSeidelBackward(dir, rowGroupIndices, *A, x, b, choices);
}
template<typename ValueType>
LinearEquationSolverProblemFormat TopologicalLinearEquationSolver<ValueType>::getEquationProblemFormat(Environment const& env) const {
return LinearEquationSolverProblemFormat::FixedPointSystem;
}
template<typename ValueType>
LinearEquationSolverRequirements TopologicalLinearEquationSolver<ValueType>::getRequirements(Environment const& env, LinearEquationSolverTask const& task) const {
// Return the requirements of the underlying solver
return GeneralLinearEquationSolverFactory<ValueType>().getRequirements(getEnvironmentForUnderlyingSolver(env), task);
}
template<typename ValueType>
void TopologicalLinearEquationSolver<ValueType>::clearCache() const {
sortedSccDecomposition.reset();
sccSolver.reset();
LinearEquationSolver<ValueType>::clearCache();
}
template<typename ValueType>
uint64_t TopologicalLinearEquationSolver<ValueType>::getMatrixRowCount() const {
return this->A->getRowCount();
}
template<typename ValueType>
uint64_t TopologicalLinearEquationSolver<ValueType>::getMatrixColumnCount() const {
return this->A->getColumnCount();
}
template<typename ValueType>
std::unique_ptr<storm::solver::LinearEquationSolver<ValueType>> TopologicalLinearEquationSolverFactory<ValueType>::create(Environment const& env, LinearEquationSolverTask const& task) const {
return std::make_unique<storm::solver::TopologicalLinearEquationSolver<ValueType>>();
}
template<typename ValueType>
std::unique_ptr<LinearEquationSolverFactory<ValueType>> TopologicalLinearEquationSolverFactory<ValueType>::clone() const {
return std::make_unique<TopologicalLinearEquationSolverFactory<ValueType>>(*this);
}
// Explicitly instantiate the linear equation solver.
template class TopologicalLinearEquationSolver<double>;
template class TopologicalLinearEquationSolverFactory<double>;
#ifdef STORM_HAVE_CARL
template class TopologicalLinearEquationSolver<storm::RationalNumber>;
template class TopologicalLinearEquationSolverFactory<storm::RationalNumber>;
template class TopologicalLinearEquationSolver<storm::RationalFunction>;
template class TopologicalLinearEquationSolverFactory<storm::RationalFunction>;
#endif
}
}

84
src/storm/solver/TopologicalLinearEquationSolver.h

@ -0,0 +1,84 @@
#pragma once
#include "storm/solver/LinearEquationSolver.h"
#include "storm/solver/SolverSelectionOptions.h"
#include "storm/solver/NativeMultiplier.h"
#include "storm/storage/StronglyConnectedComponentDecomposition.h"
namespace storm {
class Environment;
namespace solver {
template<typename ValueType>
class TopologicalLinearEquationSolver : public LinearEquationSolver<ValueType> {
public:
TopologicalLinearEquationSolver();
TopologicalLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A);
TopologicalLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A);
virtual void setMatrix(storm::storage::SparseMatrix<ValueType> const& A) override;
virtual void setMatrix(storm::storage::SparseMatrix<ValueType>&& A) override;
virtual void multiply(std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const override;
virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices = nullptr) const override;
virtual bool supportsGaussSeidelMultiplication() const override;
virtual void multiplyGaussSeidel(std::vector<ValueType>& x, std::vector<ValueType> const* b) const override;
virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<uint_fast64_t>* choices = nullptr) const override;
virtual LinearEquationSolverProblemFormat getEquationProblemFormat(storm::Environment const& env) const override;
virtual LinearEquationSolverRequirements getRequirements(Environment const& env, LinearEquationSolverTask const& task = LinearEquationSolverTask::Unspecified) const override;
virtual void clearCache() const override;
protected:
virtual bool internalSolveEquations(storm::Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const& b) const override;
private:
virtual uint64_t getMatrixRowCount() const override;
virtual uint64_t getMatrixColumnCount() const override;
storm::Environment getEnvironmentForUnderlyingSolver(storm::Environment const& env) const;
// Creates an SCC decomposition and sorts the SCCs according to a topological sort.
void createSortedSccDecomposition() const;
// Solves the SCC with the given index
// ... for the case that the SCC is trivial
bool solveTrivialScc(uint64_t const& sccState, std::vector<ValueType>& globalX, std::vector<ValueType> const& globalB) const;
// ... for the case that there is just one large SCC
bool solveFullyConnectedEquationSystem(storm::Environment const& sccSolverEnvironment, std::vector<ValueType>& x, std::vector<ValueType> const& b) const;
// ... for the remaining cases (1 < scc.size() < x.size())
bool solveScc(storm::Environment const& sccSolverEnvironment, storm::storage::BitVector const& scc, std::vector<ValueType>& globalX, std::vector<ValueType> const& globalB) const;
// If the solver takes posession of the matrix, we store the moved matrix in this member, so it gets deleted
// when the solver is destructed.
std::unique_ptr<storm::storage::SparseMatrix<ValueType>> localA;
// A pointer to the original sparse matrix given to this solver. If the solver takes posession of the matrix
// the pointer refers to localA.
storm::storage::SparseMatrix<ValueType> const* A;
// An object to dispatch all multiplication operations.
NativeMultiplier<ValueType> multiplier;
// cached auxiliary data
mutable std::unique_ptr<std::vector<storm::storage::StronglyConnectedComponent>> sortedSccDecomposition;
mutable std::unique_ptr<storm::solver::LinearEquationSolver<ValueType>> sccSolver;
};
template<typename ValueType>
class TopologicalLinearEquationSolverFactory : public LinearEquationSolverFactory<ValueType> {
public:
using LinearEquationSolverFactory<ValueType>::create;
virtual std::unique_ptr<storm::solver::LinearEquationSolver<ValueType>> create(Environment const& env, LinearEquationSolverTask const& task = LinearEquationSolverTask::Unspecified) const override;
virtual std::unique_ptr<LinearEquationSolverFactory<ValueType>> clone() const override;
};
}
}
Loading…
Cancel
Save