diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index 2c8eb5481..587aa3169 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -22,17 +22,17 @@ namespace storm { namespace solver { template - IterativeMinMaxLinearEquationSolver::IterativeMinMaxLinearEquationSolver(std::unique_ptr>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver(std::move(linearEquationSolverFactory)) { + IterativeMinMaxLinearEquationSolver::IterativeMinMaxLinearEquationSolver(std::unique_ptr>&& linearEquationSolverFactory) : linearEquationSolverFactory(std::move(linearEquationSolverFactory)) { // Intentionally left empty } template - IterativeMinMaxLinearEquationSolver::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix const& A, std::unique_ptr>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver(A, std::move(linearEquationSolverFactory)) { + IterativeMinMaxLinearEquationSolver::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix const& A, std::unique_ptr>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver(A), linearEquationSolverFactory(std::move(linearEquationSolverFactory)) { // Intentionally left empty. } template - IterativeMinMaxLinearEquationSolver::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix&& A, std::unique_ptr>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver(std::move(A), std::move(linearEquationSolverFactory)) { + IterativeMinMaxLinearEquationSolver::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix&& A, std::unique_ptr>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver(std::move(A)), linearEquationSolverFactory(std::move(linearEquationSolverFactory)) { // Intentionally left empty. } @@ -221,12 +221,11 @@ namespace storm { MinMaxLinearEquationSolverRequirements IterativeMinMaxLinearEquationSolver::getRequirements(Environment const& env, boost::optional const& direction, bool const& hasInitialScheduler) const { auto method = getMethod(env, storm::NumberTraits::IsExact); - // Start by getting the requirements of the linear equation solver. - LinearEquationSolverTask linEqTask = LinearEquationSolverTask::Unspecified; - if ((method == MinMaxMethod::ValueIteration && !this->hasInitialScheduler() && !hasInitialScheduler) || method == MinMaxMethod::RationalSearch || method == MinMaxMethod::SoundValueIteration || method == MinMaxMethod::IntervalIteration) { - linEqTask = LinearEquationSolverTask::Multiply; - } - MinMaxLinearEquationSolverRequirements requirements(this->linearEquationSolverFactory->getRequirements(env, linEqTask)); + // Check whether a linear equation solver is needed and potentially start with its requirements + bool needsLinEqSolver = false; + needsLinEqSolver |= method == MinMaxMethod::PolicyIteration; + needsLinEqSolver |= method == MinMaxMethod::ValueIteration && (this->hasInitialScheduler() || hasInitialScheduler); + MinMaxLinearEquationSolverRequirements requirements = needsLinEqSolver ? MinMaxLinearEquationSolverRequirements(this->linearEquationSolverFactory->getRequirements(env)) : MinMaxLinearEquationSolverRequirements(); if (method == MinMaxMethod::ValueIteration) { if (!this->hasUniqueSolution()) { // Traditional value iteration has no requirements if the solution is unique. @@ -275,15 +274,15 @@ namespace storm { } template - typename IterativeMinMaxLinearEquationSolver::ValueIterationResult IterativeMinMaxLinearEquationSolver::performValueIteration(OptimizationDirection dir, std::vector*& currentX, std::vector*& newX, std::vector const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const { + typename IterativeMinMaxLinearEquationSolver::ValueIterationResult IterativeMinMaxLinearEquationSolver::performValueIteration(Environment const& env, OptimizationDirection dir, std::vector*& currentX, std::vector*& newX, std::vector const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const { STORM_LOG_ASSERT(currentX != newX, "Vectors must not be aliased."); - // Get handle to linear equation solver. - storm::solver::LinearEquationSolver const& linearEquationSolver = *this->linEqSolverA; + // Get handle to multiplier. + storm::solver::Multiplier const& multiplier = *this->multiplierA; // Allow aliased multiplications. - bool useGaussSeidelMultiplication = linearEquationSolver.supportsGaussSeidelMultiplication() && multiplicationStyle == storm::solver::MultiplicationStyle::GaussSeidel; + bool useGaussSeidelMultiplication = multiplicationStyle == storm::solver::MultiplicationStyle::GaussSeidel; // Proceed with the iterations as long as the method did not converge or reach the maximum number of iterations. uint64_t iterations = currentIterations; @@ -296,9 +295,9 @@ namespace storm { if (useGaussSeidelMultiplication) { // Copy over the current vector so we can modify it in-place. *newX = *currentX; - linearEquationSolver.multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *newX, &b); + multiplier.multiplyAndReduceGaussSeidel(env, dir, *newX, &b); } else { - linearEquationSolver.multiplyAndReduce(dir, this->A->getRowGroupIndices(), *currentX, &b, *newX); + multiplier.multiplyAndReduce(env, dir, *currentX, &b, *newX); } // Determine whether the method converged. @@ -325,9 +324,8 @@ namespace storm { template bool IterativeMinMaxLinearEquationSolver::solveEquationsValueIteration(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { - if (!this->linEqSolverA) { - this->createLinearEquationSolver(env); - this->linEqSolverA->setCachingEnabled(true); + if (!this->multiplierA) { + this->multiplierA = storm::solver::MultiplierFactory().create(env, *this->A); } if (!auxiliaryRowGroupVector) { @@ -387,7 +385,7 @@ namespace storm { std::vector* currentX = &x; this->startMeasureProgress(); - ValueIterationResult result = performValueIteration(dir, currentX, newX, b, storm::utility::convertNumber(env.solver().minMax().getPrecision()), env.solver().minMax().getRelativeTerminationCriterion(), guarantee, 0, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle()); + ValueIterationResult result = performValueIteration(env, dir, currentX, newX, b, storm::utility::convertNumber(env.solver().minMax().getPrecision()), env.solver().minMax().getRelativeTerminationCriterion(), guarantee, 0, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle()); // Swap the result into the output x. if (currentX == auxiliaryRowGroupVector.get()) { @@ -399,7 +397,7 @@ namespace storm { // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulerSet()) { this->schedulerChoices = std::vector(this->A->getRowGroupCount()); - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *auxiliaryRowGroupVector.get(), &this->schedulerChoices.get()); + this->multiplierA->multiplyAndReduce(env, dir, x, &b, *auxiliaryRowGroupVector.get(), &this->schedulerChoices.get()); } if (!this->isCachingEnabled()) { @@ -443,9 +441,8 @@ namespace storm { bool IterativeMinMaxLinearEquationSolver::solveEquationsIntervalIteration(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { STORM_LOG_THROW(this->hasUpperBound(), storm::exceptions::UnmetRequirementException, "Solver requires upper bound, but none was given."); - if (!this->linEqSolverA) { - this->createLinearEquationSolver(env); - this->linEqSolverA->setCachingEnabled(true); + if (!this->multiplierA) { + this->multiplierA = storm::solver::MultiplierFactory().create(env, *this->A); } if (!auxiliaryRowGroupVector) { @@ -453,7 +450,7 @@ namespace storm { } // Allow aliased multiplications. - bool useGaussSeidelMultiplication = this->linEqSolverA->supportsGaussSeidelMultiplication() && env.solver().minMax().getMultiplicationStyle() == storm::solver::MultiplicationStyle::GaussSeidel; + bool useGaussSeidelMultiplication = env.solver().minMax().getMultiplicationStyle() == storm::solver::MultiplicationStyle::GaussSeidel; std::vector* lowerX = &x; this->createLowerBoundsVector(*lowerX); @@ -497,22 +494,22 @@ namespace storm { if (useDiffs) { preserveOldRelevantValues(*lowerX, this->getRelevantValues(), oldValues); } - this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *lowerX, &b); + this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *lowerX, &b); if (useDiffs) { maxLowerDiff = computeMaxAbsDiff(*lowerX, this->getRelevantValues(), oldValues); preserveOldRelevantValues(*upperX, this->getRelevantValues(), oldValues); } - this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *upperX, &b); + this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *upperX, &b); if (useDiffs) { maxUpperDiff = computeMaxAbsDiff(*upperX, this->getRelevantValues(), oldValues); } } else { - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *lowerX, &b, *tmp); + this->multiplierA->multiplyAndReduce(env, dir, *lowerX, &b, *tmp); if (useDiffs) { maxLowerDiff = computeMaxAbsDiff(*lowerX, *tmp, this->getRelevantValues()); } std::swap(lowerX, tmp); - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *upperX, &b, *tmp); + this->multiplierA->multiplyAndReduce(env, dir, *upperX, &b, *tmp); if (useDiffs) { maxUpperDiff = computeMaxAbsDiff(*upperX, *tmp, this->getRelevantValues()); } @@ -525,7 +522,7 @@ namespace storm { if (useDiffs) { preserveOldRelevantValues(*lowerX, this->getRelevantValues(), oldValues); } - this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *lowerX, &b); + this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *lowerX, &b); if (useDiffs) { maxLowerDiff = computeMaxAbsDiff(*lowerX, this->getRelevantValues(), oldValues); } @@ -534,7 +531,7 @@ namespace storm { if (useDiffs) { preserveOldRelevantValues(*upperX, this->getRelevantValues(), oldValues); } - this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *upperX, &b); + this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *upperX, &b); if (useDiffs) { maxUpperDiff = computeMaxAbsDiff(*upperX, this->getRelevantValues(), oldValues); } @@ -542,14 +539,14 @@ namespace storm { } } else { if (maxLowerDiff >= maxUpperDiff) { - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *lowerX, &b, *tmp); + this->multiplierA->multiplyAndReduce(env, dir, *lowerX, &b, *tmp); if (useDiffs) { maxLowerDiff = computeMaxAbsDiff(*lowerX, *tmp, this->getRelevantValues()); } std::swap(tmp, lowerX); lowerStep = true; } else { - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *upperX, &b, *tmp); + this->multiplierA->multiplyAndReduce(env, dir, *upperX, &b, *tmp); if (useDiffs) { maxUpperDiff = computeMaxAbsDiff(*upperX, *tmp, this->getRelevantValues()); } @@ -604,7 +601,7 @@ namespace storm { // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulerSet()) { this->schedulerChoices = std::vector(this->A->getRowGroupCount()); - this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get()); + this->multiplierA->multiplyAndReduce(env, dir, x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get()); } if (!this->isCachingEnabled()) { @@ -669,19 +666,24 @@ namespace storm { return maximize(dir) ? minIndex : maxIndex; } - void multiplyRow(uint64_t const& row, storm::storage::SparseMatrix const& A, ValueType const& bi, ValueType& xi, ValueType& yi) { + void multiplyRow(uint64_t const& row, storm::storage::SparseMatrix const& A, storm::solver::Multiplier const& multiplier, ValueType const& bi, ValueType& xi, ValueType& yi) { + xi = multiplier.multiplyRow(row, x, bi); + yi = multiplier.multiplyRow(row, y, storm::utility::zero()); + + /* xi = bi; yi = storm::utility::zero(); for (auto const& entry : A.getRow(row)) { xi += entry.getValue() * x[entry.getColumn()]; yi += entry.getValue() * y[entry.getColumn()]; } + */ } template - void performIterationStep(storm::storage::SparseMatrix const& A, std::vector const& b) { + void performIterationStep(storm::storage::SparseMatrix const& A, storm::solver::Multiplier const& multiplier, std::vector const& b) { if (!decisionValueBlocks) { - performIterationStepUpdateDecisionValue(A, b); + performIterationStepUpdateDecisionValue(A, multiplier, b); } else { assert(decisionValue == getPrimaryBound()); auto xIt = x.rbegin(); @@ -693,7 +695,7 @@ namespace storm { // Perform the iteration for the first row in the group uint64_t row = *groupStartIt; ValueType xBest, yBest; - multiplyRow(row, A, b[row], xBest, yBest); + multiplyRow(row, A, multiplier, b[row], xBest, yBest); ++row; // Only do more work if there are still rows in this row group if (row != groupEnd) { @@ -701,7 +703,7 @@ namespace storm { ValueType bestValue = xBest + yBest * getPrimaryBound(); for (;row < groupEnd; ++row) { // Get the multiplication results - multiplyRow(row, A, b[row], xi, yi); + multiplyRow(row, A, multiplier, b[row], xi, yi); ValueType currentValue = xi + yi * getPrimaryBound(); // Check if the current row is better then the previously found one if (better(currentValue, bestValue)) { @@ -722,7 +724,7 @@ namespace storm { } template - void performIterationStepUpdateDecisionValue(storm::storage::SparseMatrix const& A, std::vector const& b) { + void performIterationStepUpdateDecisionValue(storm::storage::SparseMatrix const& A, storm::solver::Multiplier const& multiplier, std::vector const& b) { auto xIt = x.rbegin(); auto yIt = y.rbegin(); auto groupStartIt = A.getRowGroupIndices().rbegin(); @@ -732,7 +734,7 @@ namespace storm { // Perform the iteration for the first row in the group uint64_t row = *groupStartIt; ValueType xBest, yBest; - multiplyRow(row, A, b[row], xBest, yBest); + multiplyRow(row, A, multiplier, b[row], xBest, yBest); ++row; // Only do more work if there are still rows in this row group if (row != groupEnd) { @@ -742,7 +744,7 @@ namespace storm { ValueType bestValue = xBest + yBest * getPrimaryBound(); for (;row < groupEnd; ++row) { // Get the multiplication results - multiplyRow(row, A, b[row], xi, yi); + multiplyRow(row, A, multiplier, b[row], xi, yi); ValueType currentValue = xi + yi * getPrimaryBound(); // Check if the current row is better then the previously found one if (better(currentValue, bestValue)) { @@ -769,7 +771,7 @@ namespace storm { } } else { for (;row < groupEnd; ++row) { - multiplyRow(row, A, b[row], xi, yi); + multiplyRow(row, A, multiplier, b[row], xi, yi); // Update the best choice if (yi > yBest || (yi == yBest && better(xi, xBest))) { xTmp[xyTmpIndex] = std::move(xBest); @@ -978,6 +980,10 @@ namespace storm { this->auxiliaryRowGroupVector = std::make_unique>(); } + if (!this->multiplierA) { + this->multiplierA = storm::solver::MultiplierFactory().create(env, *this->A); + } + SoundValueIterationHelper helper(x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().minMax().getPrecision()), this->A->getSizeOfLargestRowGroup()); // Prepare initial bounds for the solution (if given) @@ -999,13 +1005,13 @@ namespace storm { while (status == SolverStatus::InProgress && iterations < env.solver().minMax().getMaximalNumberOfIterations()) { if (minimize(dir)) { - helper.template performIterationStep(*this->A, b); + helper.template performIterationStep(*this->A, *this->multiplierA, b); if (helper.template checkConvergenceUpdateBounds(relevantValuesPtr)) { status = SolverStatus::Converged; } } else { assert(maximize(dir)); - helper.template performIterationStep(*this->A, b); + helper.template performIterationStep(*this->A, *this->multiplierA, b); if (helper.template checkConvergenceUpdateBounds(relevantValuesPtr)) { status = SolverStatus::Converged; } @@ -1085,11 +1091,6 @@ namespace storm { return false; } - template - void IterativeMinMaxLinearEquationSolver::createLinearEquationSolver(Environment const& env) const { - this->linEqSolverA = this->linearEquationSolverFactory->create(env, *this->A, LinearEquationSolverTask::Multiply); - } - template template typename std::enable_if::value && !NumberTraits::IsExact, bool>::type IterativeMinMaxLinearEquationSolver::solveEquationsRationalSearchHelper(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { @@ -1100,9 +1101,8 @@ namespace storm { std::vector rationalX(x.size()); std::vector rationalB = storm::utility::vector::convertNumericVector(b); - if (!this->linEqSolverA) { - this->createLinearEquationSolver(env); - this->linEqSolverA->setCachingEnabled(true); + if (!this->multiplierA) { + this->multiplierA = storm::solver::MultiplierFactory().create(env, *this->A); } if (!auxiliaryRowGroupVector) { @@ -1130,9 +1130,8 @@ namespace storm { typename std::enable_if::value && NumberTraits::IsExact, bool>::type IterativeMinMaxLinearEquationSolver::solveEquationsRationalSearchHelper(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { // Version for when the overall value type is exact and the same type is to be used for the imprecise part. - if (!this->linEqSolverA) { - this->createLinearEquationSolver(env); - this->linEqSolverA->setCachingEnabled(true); + if (!this->multiplierA) { + this->multiplierA = storm::solver::MultiplierFactory().create(env, *this->A); } if (!auxiliaryRowGroupVector) { @@ -1182,13 +1181,14 @@ namespace storm { // Create imprecise solver from the imprecise data. IterativeMinMaxLinearEquationSolver impreciseSolver(std::make_unique>()); impreciseSolver.setMatrix(impreciseA); - impreciseSolver.createLinearEquationSolver(env); impreciseSolver.setCachingEnabled(true); + impreciseSolver.multiplierA = storm::solver::MultiplierFactory().create(env, impreciseA); bool converged = false; try { // Forward the call to the core rational search routine. converged = solveEquationsRationalSearchHelper(env, dir, impreciseSolver, *this->A, x, b, impreciseA, impreciseX, impreciseB, impreciseTmpX); + impreciseSolver.clearCache(); } catch (storm::exceptions::PrecisionExceededException const& e) { STORM_LOG_WARN("Precision of value type was exceeded, trying to recover by switching to rational arithmetic."); @@ -1208,9 +1208,8 @@ namespace storm { impreciseB = std::vector(); impreciseA = storm::storage::SparseMatrix(); - if (!this->linEqSolverA) { - createLinearEquationSolver(env); - this->linEqSolverA->setCachingEnabled(true); + if (!this->multiplierA) { + this->multiplierA = storm::solver::MultiplierFactory().create(env, *this->A); } // Forward the call to the core rational search routine, but now with our value type as the imprecise value type. @@ -1270,7 +1269,7 @@ namespace storm { impreciseSolver.startMeasureProgress(); while (status == SolverStatus::InProgress && overallIterations < env.solver().minMax().getMaximalNumberOfIterations()) { // Perform value iteration with the current precision. - typename IterativeMinMaxLinearEquationSolver::ValueIterationResult result = impreciseSolver.performValueIteration(dir, currentX, newX, b, storm::utility::convertNumber(precision), env.solver().minMax().getRelativeTerminationCriterion(), SolverGuarantee::LessOrEqual, overallIterations, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle()); + typename IterativeMinMaxLinearEquationSolver::ValueIterationResult result = impreciseSolver.performValueIteration(env, dir, currentX, newX, b, storm::utility::convertNumber(precision), env.solver().minMax().getRelativeTerminationCriterion(), SolverGuarantee::LessOrEqual, overallIterations, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle()); // At this point, the result of the imprecise value iteration is stored in the (imprecise) current x. @@ -1370,45 +1369,16 @@ namespace storm { template void IterativeMinMaxLinearEquationSolver::clearCache() const { + multiplierA.reset(); auxiliaryRowGroupVector.reset(); auxiliaryRowGroupVector2.reset(); - rowGroupOrdering.reset(); StandardMinMaxLinearEquationSolver::clearCache(); } - template - IterativeMinMaxLinearEquationSolverFactory::IterativeMinMaxLinearEquationSolverFactory() : StandardMinMaxLinearEquationSolverFactory() { - // Intentionally left empty - } - - template - IterativeMinMaxLinearEquationSolverFactory::IterativeMinMaxLinearEquationSolverFactory(std::unique_ptr>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolverFactory(std::move(linearEquationSolverFactory)) { - // Intentionally left empty - } - - template - IterativeMinMaxLinearEquationSolverFactory::IterativeMinMaxLinearEquationSolverFactory(EquationSolverType const& solverType) : StandardMinMaxLinearEquationSolverFactory(solverType) { - // Intentionally left empty - } - - template - std::unique_ptr> IterativeMinMaxLinearEquationSolverFactory::create(Environment const& env) const { - STORM_LOG_ASSERT(this->linearEquationSolverFactory, "Linear equation solver factory not initialized."); - - auto method = env.solver().minMax().getMethod(); - STORM_LOG_THROW(method == MinMaxMethod::ValueIteration || method == MinMaxMethod::PolicyIteration || method == MinMaxMethod::RationalSearch || method == MinMaxMethod::IntervalIteration || method == MinMaxMethod::SoundValueIteration, storm::exceptions::InvalidEnvironmentException, "This solver does not support the selected method."); - - std::unique_ptr> result = std::make_unique>(this->linearEquationSolverFactory->clone()); - result->setRequirementsChecked(this->isRequirementsCheckedSet()); - return result; - } - template class IterativeMinMaxLinearEquationSolver; - template class IterativeMinMaxLinearEquationSolverFactory; #ifdef STORM_HAVE_CARL template class IterativeMinMaxLinearEquationSolver; - template class IterativeMinMaxLinearEquationSolverFactory; #endif } } diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h index 5806ec4a2..1a5501089 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h @@ -5,6 +5,7 @@ #include "storm/utility/NumberTraits.h" #include "storm/solver/LinearEquationSolver.h" +#include "storm/solver/Multiplier.h" #include "storm/solver/StandardMinMaxLinearEquationSolver.h" #include "storm/solver/SolverStatus.h" @@ -68,31 +69,21 @@ namespace storm { template friend class IterativeMinMaxLinearEquationSolver; - ValueIterationResult performValueIteration(OptimizationDirection dir, std::vector*& currentX, std::vector*& newX, std::vector const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const; + ValueIterationResult performValueIteration(Environment const& env, OptimizationDirection dir, std::vector*& currentX, std::vector*& newX, std::vector const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const; void createLinearEquationSolver(Environment const& env) const; + /// The factory used to obtain linear equation solvers. + std::unique_ptr> linearEquationSolverFactory; + // possibly cached data + mutable std::unique_ptr> multiplierA; mutable std::unique_ptr> auxiliaryRowGroupVector; // A.rowGroupCount() entries mutable std::unique_ptr> auxiliaryRowGroupVector2; // A.rowGroupCount() entries - mutable std::unique_ptr> rowGroupOrdering; // A.rowGroupCount() entries SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const; static void reportStatus(SolverStatus status, uint64_t iterations); }; - template - class IterativeMinMaxLinearEquationSolverFactory : public StandardMinMaxLinearEquationSolverFactory { - public: - IterativeMinMaxLinearEquationSolverFactory(); - IterativeMinMaxLinearEquationSolverFactory(std::unique_ptr>&& linearEquationSolverFactory); - IterativeMinMaxLinearEquationSolverFactory(EquationSolverType const& solverType); - - // Make the other create methods visible. - using MinMaxLinearEquationSolverFactory::create; - - virtual std::unique_ptr> create(Environment const& env) const override; - - }; } }