diff --git a/src/storm/solver/Multiplier.cpp b/src/storm/solver/Multiplier.cpp index e48b29df2..b2b6dfd5e 100644 --- a/src/storm/solver/Multiplier.cpp +++ b/src/storm/solver/Multiplier.cpp @@ -10,6 +10,7 @@ #include "storm/utility/macros.h" #include "storm/solver/SolverSelectionOptions.h" #include "storm/solver/NativeMultiplier.h" +#include "storm/solver/GmmxxMultiplier.h" #include "storm/environment/solver/MultiplierEnvironment.h" namespace storm { @@ -26,12 +27,12 @@ namespace storm { } template - void Multiplier::multiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector const& x, std::vector const* b, std::vector& result, std::vector* choices = nullptr) { + void Multiplier::multiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector const& x, std::vector const* b, std::vector& result, std::vector* choices) const { multiplyAndReduce(env, dir, this->matrix.getRowGroupIndices(), x, b, result, choices); } template - void Multiplier::multiplyAndReduceGaussSeidel(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, std::vector* choices = nullptr) { + void Multiplier::multiplyAndReduceGaussSeidel(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, std::vector* choices) const { multiplyAndReduceGaussSeidel(env, dir, this->matrix.getRowGroupIndices(), x, b, choices); } @@ -43,9 +44,9 @@ namespace storm { } template - void Multiplier::repeatedMultiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, uint64_t n) const { + void Multiplier::repeatedMultiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, uint64_t n) const { for (uint64_t i = 0; i < n; ++i) { - multiplyAndReduce(env, dir, rowGroupIndices, x, b, x); + multiplyAndReduce(env, dir, x, b, x); } } @@ -53,8 +54,7 @@ namespace storm { std::unique_ptr> MultiplierFactory::create(Environment const& env, storm::storage::SparseMatrix const& matrix) { switch (env.solver().multiplier().getType()) { case MultiplierType::Gmmxx: - //return std::make_unique>(matrix); - STORM_PRINT_AND_LOG("gmm mult not yet supported"); + return std::make_unique>(matrix); case MultiplierType::Native: return std::make_unique>(matrix); } diff --git a/src/storm/solver/NativeLinearEquationSolver.cpp b/src/storm/solver/NativeLinearEquationSolver.cpp index 47f07f6f8..c524513f2 100644 --- a/src/storm/solver/NativeLinearEquationSolver.cpp +++ b/src/storm/solver/NativeLinearEquationSolver.cpp @@ -301,11 +301,7 @@ namespace storm { } template - typename NativeLinearEquationSolver::PowerIterationResult NativeLinearEquationSolver::performPowerIteration(Environment const& env, std::vector*& currentX, std::vector*& newX, std::vector const& b, storm::solver::Multiplier const& multiplier, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maxIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const { - - if (!this->multiplier) { - this->multiplier = storm::solver::MultiplierFactory().create(env, *A); - } + typename NativeLinearEquationSolver::PowerIterationResult NativeLinearEquationSolver::performPowerIteration(Environment const& env, std::vector*& currentX, std::vector*& newX, std::vector const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maxIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const { bool useGaussSeidelMultiplication = multiplicationStyle == storm::solver::MultiplicationStyle::GaussSeidel; @@ -317,9 +313,9 @@ namespace storm { while (!converged && !terminate && iterations < maxIterations) { if (useGaussSeidelMultiplication) { *newX = *currentX; - multiplier.multiplyGaussSeidel(env, *newX, &b); + this->multiplier->multiplyGaussSeidel(env, *newX, &b); } else { - multiplier.multiply(env, *currentX, &b, *newX); + this->multiplier->multiply(env, *currentX, &b, *newX); } // Now check for termination. @@ -369,7 +365,7 @@ namespace storm { // Forward call to power iteration implementation. this->startMeasureProgress(); ValueType precision = storm::utility::convertNumber(env.solver().native().getPrecision()); - PowerIterationResult result = this->performPowerIteration(env, currentX, newX, b, *this->multiplier, precision, env.solver().native().getRelativeTerminationCriterion(), guarantee, 0, env.solver().native().getMaximalNumberOfIterations(), env.solver().native().getPowerMethodMultiplicationStyle()); + PowerIterationResult result = this->performPowerIteration(env, currentX, newX, b, precision, env.solver().native().getRelativeTerminationCriterion(), guarantee, 0, env.solver().native().getMaximalNumberOfIterations(), env.solver().native().getPowerMethodMultiplicationStyle()); // Swap the result in place. if (currentX == this->cachedRowVector.get()) { @@ -596,7 +592,7 @@ namespace storm { void multiplyRow(uint64_t const& row, storm::storage::SparseMatrix const& A, storm::solver::Multiplier const& multiplier, ValueType const& bi, ValueType& xi, ValueType& yi) { xi = multiplier.multiplyRow(row, x, bi); - yi = multiplier.multiplyRow(row, y, storm::utility::zero()); + yi = multiplier.multiplyRow(row, y, storm::utility::zero()); /* xi = bi; yi = storm::utility::zero(); @@ -873,7 +869,7 @@ namespace storm { impreciseSolver.startMeasureProgress(); while (status == SolverStatus::InProgress && overallIterations < maxIter) { // Perform value iteration with the current precision. - typename NativeLinearEquationSolver::PowerIterationResult result = impreciseSolver.performPowerIteration(currentX, newX, b, storm::utility::convertNumber(precision), relative, SolverGuarantee::LessOrEqual, overallIterations, maxIter, multiplicationStyle); + typename NativeLinearEquationSolver::PowerIterationResult result = impreciseSolver.performPowerIteration(env, currentX, newX, b, storm::utility::convertNumber(precision), relative, SolverGuarantee::LessOrEqual, overallIterations, maxIter, multiplicationStyle); // At this point, the result of the imprecise value iteration is stored in the (imprecise) current x. @@ -1143,7 +1139,7 @@ namespace storm { } template - LinearEquationSolverRequirements NativeLinearEquationSolver::getRequirements(Environment const& env, LinearEquationSolverTask const& task) const { + LinearEquationSolverRequirements NativeLinearEquationSolver::getRequirements(Environment const& env) const { LinearEquationSolverRequirements requirements; if (env.solver().native().isForceBoundsSet()) { requirements.requireBounds(); diff --git a/src/storm/solver/NativeLinearEquationSolver.h b/src/storm/solver/NativeLinearEquationSolver.h index 7640f19eb..46ba3628e 100644 --- a/src/storm/solver/NativeLinearEquationSolver.h +++ b/src/storm/solver/NativeLinearEquationSolver.h @@ -30,13 +30,6 @@ namespace storm { virtual void setMatrix(storm::storage::SparseMatrix const& A) override; virtual void setMatrix(storm::storage::SparseMatrix&& A) override; - virtual void multiply(std::vector& x, std::vector const* b, std::vector& result) const override; - virtual void multiplyAndReduce(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector& result, std::vector* choices = nullptr) const override; - virtual bool supportsGaussSeidelMultiplication() const override; - virtual void multiplyGaussSeidel(std::vector& x, std::vector const* b) const override; - virtual void multiplyAndReduceGaussSeidel(OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector& x, std::vector const* b, std::vector* choices = nullptr) const override; - virtual ValueType multiplyRow(uint64_t const& rowIndex, std::vector const& x) const override; - virtual LinearEquationSolverProblemFormat getEquationProblemFormat(storm::Environment const& env) const override; virtual LinearEquationSolverRequirements getRequirements(Environment const& env) const override; diff --git a/src/storm/solver/NativeMultiplier.cpp b/src/storm/solver/NativeMultiplier.cpp index 5fe7dde67..defe462f4 100644 --- a/src/storm/solver/NativeMultiplier.cpp +++ b/src/storm/solver/NativeMultiplier.cpp @@ -32,7 +32,6 @@ namespace storm { template void NativeMultiplier::multiply(Environment const& env, std::vector const& x, std::vector const* b, std::vector& result) const { - STORM_LOG_ASSERT(getMultiplicationStyle() == MultiplicationStyle::Regular, "Unexpected Multiplicationstyle."); std::vector* target = &result; if (&x == &result) { if (this->cachedVector) { @@ -53,13 +52,12 @@ namespace storm { } template - void NativeMultiplier::multiplyGaussSeidel(Environment const& env, std::vector const& x, std::vector const* b) const { + void NativeMultiplier::multiplyGaussSeidel(Environment const& env, std::vector& x, std::vector const* b) const { this->matrix.multiplyWithVectorBackward(x, x, b); } template void NativeMultiplier::multiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector const& x, std::vector const* b, std::vector& result, std::vector* choices) const { - STORM_LOG_ASSERT(getMultiplicationStyle() == MultiplicationStyle::Regular, "Unexpected Multiplicationstyle."); std::vector* target = &result; if (&x == &result) { if (this->cachedVector) { diff --git a/src/storm/solver/NativeMultiplier.h b/src/storm/solver/NativeMultiplier.h index e91dc519e..e1cc4d01a 100644 --- a/src/storm/solver/NativeMultiplier.h +++ b/src/storm/solver/NativeMultiplier.h @@ -17,8 +17,6 @@ namespace storm { public: NativeMultiplier(storm::storage::SparseMatrix const& matrix); - virtual MultiplicationStyle getMultiplicationStyle() const override; - virtual void multiply(Environment const& env, std::vector const& x, std::vector const* b, std::vector& result) const override; virtual void multiplyGaussSeidel(Environment const& env, std::vector& x, std::vector const* b) const override; virtual void multiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector const& rowGroupIndices, std::vector const& x, std::vector const* b, std::vector& result, std::vector* choices = nullptr) const override; diff --git a/src/storm/solver/SolveGoal.h b/src/storm/solver/SolveGoal.h index d93b86b91..9bd6f7e73 100644 --- a/src/storm/solver/SolveGoal.h +++ b/src/storm/solver/SolveGoal.h @@ -7,7 +7,6 @@ #include "storm/solver/OptimizationDirection.h" #include "storm/logic/ComparisonType.h" #include "storm/storage/BitVector.h" -#include "storm/solver/LinearEquationSolverTask.h" #include "storm/solver/LinearEquationSolver.h" #include "storm/solver/MinMaxLinearEquationSolver.h" @@ -111,8 +110,8 @@ namespace storm { } template - std::unique_ptr> configureLinearEquationSolver(Environment const& env, SolveGoal&& goal, storm::solver::LinearEquationSolverFactory const& factory, MatrixType&& matrix, storm::solver::LinearEquationSolverTask const& task = LinearEquationSolverTask::Unspecified) { - std::unique_ptr> solver = factory.create(env, std::forward(matrix), task); + std::unique_ptr> configureLinearEquationSolver(Environment const& env, SolveGoal&& goal, storm::solver::LinearEquationSolverFactory const& factory, MatrixType&& matrix) { + std::unique_ptr> solver = factory.create(env, std::forward(matrix)); if (goal.isBounded()) { solver->setTerminationCondition(std::make_unique>(goal.relevantValues(), goal.boundIsStrict(), goal.thresholdValue(), goal.minimize())); } @@ -120,8 +119,8 @@ namespace storm { } template - std::unique_ptr> configureLinearEquationSolver(Environment const& env, SolveGoal&& goal, storm::solver::LinearEquationSolverFactory const& factory, MatrixType&& matrix, storm::solver::LinearEquationSolverTask const& task = LinearEquationSolverTask::Unspecified) { - std::unique_ptr> solver = factory.create(env, std::forward(matrix), task); + std::unique_ptr> configureLinearEquationSolver(Environment const& env, SolveGoal&& goal, storm::solver::LinearEquationSolverFactory const& factory, MatrixType&& matrix) { + std::unique_ptr> solver = factory.create(env, std::forward(matrix)); return solver; }