diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index 4bb4cc766..cbb1f30b9 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -10,7 +10,6 @@ #include "storm/utility/KwekMehlhorn.h" #include "storm/utility/NumberTraits.h" -#include "storm/solver/helper/SoundValueIterationHelper.h" #include "storm/utility/Stopwatch.h" #include "storm/utility/vector.h" #include "storm/utility/macros.h" @@ -608,21 +607,23 @@ namespace storm { template bool IterativeMinMaxLinearEquationSolver::solveEquationsSoundValueIteration(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { - // Prepare the solution vectors. + // Prepare the solution vectors and the helper. assert(x.size() == this->A->getRowGroupCount()); if (!this->auxiliaryRowGroupVector) { this->auxiliaryRowGroupVector = std::make_unique>(); } - - // TODO: implement caching for the helper - storm::solver::helper::SoundValueIterationHelper helper(*this->A, x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().minMax().getPrecision())); - + if (!this->soundValueIterationHelper) { + this->soundValueIterationHelper = std::make_unique>(*this->A, x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().minMax().getPrecision())); + } else { + this->soundValueIterationHelper = std::make_unique>(std::move(*this->soundValueIterationHelper), x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().minMax().getPrecision())); + } + // Prepare initial bounds for the solution (if given) if (this->hasLowerBound()) { - helper.setLowerBound(this->getLowerBound(true)); + this->soundValueIterationHelper->setLowerBound(this->getLowerBound(true)); } if (this->hasUpperBound()) { - helper.setUpperBound(this->getUpperBound(true)); + this->soundValueIterationHelper->setUpperBound(this->getUpperBound(true)); } storm::storage::BitVector const* relevantValuesPtr = nullptr; @@ -635,8 +636,8 @@ namespace storm { uint64_t iterations = 0; while (status == SolverStatus::InProgress && iterations < env.solver().minMax().getMaximalNumberOfIterations()) { - helper.performIterationStep(dir, b); - if (helper.checkConvergenceUpdateBounds(dir, relevantValuesPtr)) { + this->soundValueIterationHelper->performIterationStep(dir, b); + if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(dir, relevantValuesPtr)) { status = SolverStatus::Converged; } @@ -648,7 +649,7 @@ namespace storm { // Potentially show progress. this->showProgressIterative(iterations); } - helper.setSolutionVector(); + this->soundValueIterationHelper->setSolutionVector(); // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulerSet()) { @@ -1000,6 +1001,7 @@ namespace storm { multiplierA.reset(); auxiliaryRowGroupVector.reset(); auxiliaryRowGroupVector2.reset(); + soundValueIterationHelper.reset(); StandardMinMaxLinearEquationSolver::clearCache(); } diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h index 1a5501089..33da6e68e 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h @@ -7,6 +7,7 @@ #include "storm/solver/LinearEquationSolver.h" #include "storm/solver/Multiplier.h" #include "storm/solver/StandardMinMaxLinearEquationSolver.h" +#include "storm/solver/helper/SoundValueIterationHelper.h" #include "storm/solver/SolverStatus.h" @@ -80,6 +81,7 @@ namespace storm { mutable std::unique_ptr> multiplierA; mutable std::unique_ptr> auxiliaryRowGroupVector; // A.rowGroupCount() entries mutable std::unique_ptr> auxiliaryRowGroupVector2; // A.rowGroupCount() entries + mutable std::unique_ptr> soundValueIterationHelper; SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const; static void reportStatus(SolverStatus status, uint64_t iterations); diff --git a/src/storm/solver/NativeLinearEquationSolver.cpp b/src/storm/solver/NativeLinearEquationSolver.cpp index e7d1cf5e5..9cd6b14d2 100644 --- a/src/storm/solver/NativeLinearEquationSolver.cpp +++ b/src/storm/solver/NativeLinearEquationSolver.cpp @@ -568,21 +568,23 @@ namespace storm { template bool NativeLinearEquationSolver::solveEquationsSoundValueIteration(Environment const& env, std::vector& x, std::vector const& b) const { - // Prepare the solution vectors. + // Prepare the solution vectors and the helper. assert(x.size() == this->A->getRowCount()); if (!this->cachedRowVector) { this->cachedRowVector = std::make_unique>(); } - - // TODO: implement caching for the helper - storm::solver::helper::SoundValueIterationHelper helper(*this->A, x, *this->cachedRowVector, env.solver().native().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().native().getPrecision())); + if (!this->soundValueIterationHelper) { + this->soundValueIterationHelper = std::make_unique>(*this->A, x, *this->cachedRowVector, env.solver().native().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().native().getPrecision())); + } else { + this->soundValueIterationHelper = std::make_unique>(std::move(*this->soundValueIterationHelper), x, *this->cachedRowVector, env.solver().native().getRelativeTerminationCriterion(), storm::utility::convertNumber(env.solver().native().getPrecision())); + } // Prepare initial bounds for the solution (if given) if (this->hasLowerBound()) { - helper.setLowerBound(this->getLowerBound(true)); + this->soundValueIterationHelper->setLowerBound(this->getLowerBound(true)); } if (this->hasUpperBound()) { - helper.setUpperBound(this->getUpperBound(true)); + this->soundValueIterationHelper->setUpperBound(this->getUpperBound(true)); } storm::storage::BitVector const* relevantValuesPtr = nullptr; @@ -596,8 +598,8 @@ namespace storm { uint64_t iterations = 0; while (!converged && iterations < env.solver().native().getMaximalNumberOfIterations()) { - helper.performIterationStep(b); - if (helper.checkConvergenceUpdateBounds(relevantValuesPtr)) { + this->soundValueIterationHelper->performIterationStep(b); + if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(relevantValuesPtr)) { converged = true; } @@ -610,7 +612,7 @@ namespace storm { // Potentially show progress. this->showProgressIterative(iterations); } - helper.setSolutionVector(); + this->soundValueIterationHelper->setSolutionVector(); this->logIterations(converged, terminate, iterations); @@ -973,6 +975,7 @@ namespace storm { cachedRowVector2.reset(); walkerChaeData.reset(); multiplier.reset(); + soundValueIterationHelper.reset(); LinearEquationSolver::clearCache(); } diff --git a/src/storm/solver/NativeLinearEquationSolver.h b/src/storm/solver/NativeLinearEquationSolver.h index 6622de17f..0cc493535 100644 --- a/src/storm/solver/NativeLinearEquationSolver.h +++ b/src/storm/solver/NativeLinearEquationSolver.h @@ -8,6 +8,7 @@ #include "storm/solver/SolverSelectionOptions.h" #include "storm/solver/NativeMultiplier.h" #include "storm/solver/SolverStatus.h" +#include "storm/solver/helper/SoundValueIterationHelper.h" #include "storm/utility/NumberTraits.h" @@ -93,6 +94,7 @@ namespace storm { // cached auxiliary data mutable std::unique_ptr> cachedRowVector2; // A.getRowCount() rows + mutable std::unique_ptr> soundValueIterationHelper; struct JacobiDecomposition { JacobiDecomposition(Environment const& env, storm::storage::SparseMatrix const& A); diff --git a/src/storm/solver/helper/SoundValueIterationHelper.cpp b/src/storm/solver/helper/SoundValueIterationHelper.cpp index c208cea55..4a831e7c7 100644 --- a/src/storm/solver/helper/SoundValueIterationHelper.cpp +++ b/src/storm/solver/helper/SoundValueIterationHelper.cpp @@ -296,7 +296,6 @@ namespace storm { << "."); } - template bool SoundValueIterationHelper::checkConvergencePhase1() {