From b745b10b77f6bab1c1455fb1ecf86891e23995ec Mon Sep 17 00:00:00 2001 From: Matthias Volk Date: Thu, 5 Mar 2020 18:25:00 +0100 Subject: [PATCH] Moved reportStatus() and updateStatusIfNotConverged() to AbstractEquationSolver --- src/storm/solver/AbstractEquationSolver.cpp | 65 ++++++++- src/storm/solver/AbstractEquationSolver.h | 20 +++ .../IterativeMinMaxLinearEquationSolver.cpp | 49 ++----- .../IterativeMinMaxLinearEquationSolver.h | 2 - .../solver/NativeLinearEquationSolver.cpp | 130 +++++++++++------- src/storm/solver/StandardGameSolver.cpp | 16 +-- src/storm/solver/StandardGameSolver.h | 3 +- 7 files changed, 174 insertions(+), 111 deletions(-) diff --git a/src/storm/solver/AbstractEquationSolver.cpp b/src/storm/solver/AbstractEquationSolver.cpp index 69a2d0fad..031599847 100644 --- a/src/storm/solver/AbstractEquationSolver.cpp +++ b/src/storm/solver/AbstractEquationSolver.cpp @@ -2,14 +2,15 @@ #include "storm/adapters/RationalNumberAdapter.h" #include "storm/adapters/RationalFunctionAdapter.h" - +#include "storm/exceptions/InvalidOperationException.h" +#include "storm/exceptions/InvalidStateException.h" +#include "storm/exceptions/UnmetRequirementException.h" #include "storm/settings/SettingsManager.h" #include "storm/settings/modules/GeneralSettings.h" - #include "storm/utility/constants.h" #include "storm/utility/macros.h" -#include "storm/exceptions/UnmetRequirementException.h" -#include "storm/exceptions/InvalidOperationException.h" +#include "storm/utility/SignalHandler.h" + namespace storm { namespace solver { @@ -272,6 +273,62 @@ namespace storm { this->progressMeasurement->updateProgress(iteration); } } + + + template + void AbstractEquationSolver::reportStatus(SolverStatus status, boost::optional const& iterations) const { + if (iterations) { + switch (status) { + case SolverStatus::Converged: + STORM_LOG_TRACE("Iterative solver converged after " << iterations.get() << " iterations."); + break; + case SolverStatus::TerminatedEarly: + STORM_LOG_TRACE("Iterative solver terminated early after " << iterations.get() << " iterations."); + break; + case SolverStatus::MaximalIterationsExceeded: + STORM_LOG_WARN("Iterative solver did not converge after " << iterations.get() << " iterations."); + break; + case SolverStatus::Aborted: + STORM_LOG_WARN("Iterative solver was aborted."); + break; + default: + STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly."); + } + } else { + switch (status) { + case SolverStatus::Converged: + STORM_LOG_TRACE("Solver converged."); + break; + case SolverStatus::TerminatedEarly: + STORM_LOG_TRACE("Solver terminated early."); + break; + case SolverStatus::MaximalIterationsExceeded: + STORM_LOG_ASSERT(false, "Non-iterative solver should not exceed maximal number of iterations."); + STORM_LOG_WARN("Solver did not converge."); + break; + case SolverStatus::Aborted: + STORM_LOG_WARN("Solver was aborted."); + break; + default: + STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Solver terminated unexpectedly."); + } + } + } + + + template + SolverStatus AbstractEquationSolver::updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const { + if (status != SolverStatus::Converged) { + if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) { + status = SolverStatus::TerminatedEarly; + } else if (iterations >= maximalNumberOfIterations) { + status = SolverStatus::MaximalIterationsExceeded; + } else if (storm::utility::resources::isTerminate()) { + status = SolverStatus::Aborted; + } + } + return status; + } template class AbstractEquationSolver; template class AbstractEquationSolver; diff --git a/src/storm/solver/AbstractEquationSolver.h b/src/storm/solver/AbstractEquationSolver.h index 9f9efb411..00fd7db9b 100644 --- a/src/storm/solver/AbstractEquationSolver.h +++ b/src/storm/solver/AbstractEquationSolver.h @@ -6,9 +6,11 @@ #include #include +#include "storm/solver/SolverStatus.h" #include "storm/solver/TerminationCondition.h" #include "storm/utility/ProgressMeasurement.h" + namespace storm { namespace solver { @@ -191,6 +193,24 @@ namespace storm { void createUpperBoundsVector(std::unique_ptr>& upperBoundsVector, uint64_t length) const; void createLowerBoundsVector(std::vector& lowerBoundsVector) const; + /*! + * Report the current status of the solver. + * @param status Solver status. + * @param iterations Number of iterations (if solver is iterative). + */ + void reportStatus(SolverStatus status, boost::optional const& iterations = boost::none) const; + + /*! + * Update the status of the solver with respect to convergence, early termination, abortion, etc. + * @param status Current status. + * @param x Vector x. + * @param iterations Current number of iterations. + * @param maximalNumberOfIterations Maximal number of iterations. + * @param guarantee Guarantee. + * @return New status. + */ + SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const; + // A termination condition to be used (can be unset). std::unique_ptr> terminationCondition; diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index 1c5506bae..e60c3c6ba 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -199,16 +199,13 @@ namespace storm { // Update environment variables. ++iterations; - status = updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual); + status = this->updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual); // Potentially show progress. this->showProgressIterative(iterations); - if (storm::utility::resources::isTerminate()) { - status = SolverStatus::Aborted; - } } while (status == SolverStatus::InProgress); - reportStatus(status, iterations); + this->reportStatus(status, iterations); // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulerSet()) { @@ -331,7 +328,7 @@ namespace storm { // Update environment variables. std::swap(currentX, newX); ++iterations; - status = updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee); + status = this->updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee); // Potentially show progress. this->showProgressIterative(iterations); @@ -417,7 +414,7 @@ namespace storm { auto two = storm::utility::convertNumber(2.0); storm::utility::vector::applyPointwise(*lowerX, *upperX, x, [&two] (ValueType const& a, ValueType const& b) -> ValueType { return (a + b) / two; }); - reportStatus(statusIters.first, statusIters.second); + this->reportStatus(statusIters.first, statusIters.second); // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulerSet()) { @@ -503,7 +500,7 @@ namespace storm { std::swap(x, *currentX); } - reportStatus(result.status, result.iterations); + this->reportStatus(result.status, result.iterations); // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulerSet()) { @@ -666,17 +663,17 @@ namespace storm { ++iterations; doConvergenceCheck = !doConvergenceCheck; if (lowerStep) { - status = updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual); + status = this->updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual); } if (upperStep) { - status = updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual); + status = this->updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual); } // Potentially show progress. this->showProgressIterative(iterations); } - reportStatus(status, iterations); + this->reportStatus(status, iterations); // We take the means of the lower and upper bound so we guarantee the desired precision. ValueType two = storm::utility::convertNumber(2.0); @@ -761,7 +758,7 @@ namespace storm { this->A->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get()); } - reportStatus(status, iterations); + this->reportStatus(status, iterations); if (!this->isCachingEnabled()) { clearCache(); @@ -1064,7 +1061,7 @@ namespace storm { status = SolverStatus::MaximalIterationsExceeded; } - reportStatus(status, overallIterations); + this->reportStatus(status, overallIterations); return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly; } @@ -1103,32 +1100,6 @@ namespace storm { *choice = optimalRow - this->A->getRowGroupIndices()[group]; } } - - template - SolverStatus IterativeMinMaxLinearEquationSolver::updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const { - if (status != SolverStatus::Converged) { - if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) { - status = SolverStatus::TerminatedEarly; - } else if (iterations >= maximalNumberOfIterations) { - status = SolverStatus::MaximalIterationsExceeded; - } else if (storm::utility::resources::isTerminate()) { - status = SolverStatus::Aborted; - } - } - return status; - } - - template - void IterativeMinMaxLinearEquationSolver::reportStatus(SolverStatus status, uint64_t iterations) { - switch (status) { - case SolverStatus::Converged: STORM_LOG_TRACE("Iterative solver converged after " << iterations << " iterations."); break; - case SolverStatus::TerminatedEarly: STORM_LOG_TRACE("Iterative solver terminated early after " << iterations << " iterations."); break; - case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break; - case SolverStatus::Aborted: STORM_LOG_WARN("Iterative solver was aborted."); break; - default: - STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly."); - } - } template void IterativeMinMaxLinearEquationSolver::clearCache() const { diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h index fe818ecd8..6582caefa 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h @@ -86,8 +86,6 @@ namespace storm { mutable std::unique_ptr> auxiliaryRowGroupVector2; // A.rowGroupCount() entries mutable std::unique_ptr> soundValueIterationHelper; - SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const; - static void reportStatus(SolverStatus status, uint64_t iterations); }; } diff --git a/src/storm/solver/NativeLinearEquationSolver.cpp b/src/storm/solver/NativeLinearEquationSolver.cpp index 067a2a5cf..57f89eb48 100644 --- a/src/storm/solver/NativeLinearEquationSolver.cpp +++ b/src/storm/solver/NativeLinearEquationSolver.cpp @@ -65,19 +65,22 @@ namespace storm { // Set up additional environment variables. uint_fast64_t iterations = 0; - bool converged = false; - bool terminate = false; + SolverStatus status = SolverStatus::InProgress; this->startMeasureProgress(); - while (!converged && !terminate && iterations < maxIter) { + while (status == SolverStatus::InProgress && iterations < maxIter) { A->performSuccessiveOverRelaxationStep(omega, x, b); // Now check if the process already converged within our precision. - converged = storm::utility::vector::equalModuloPrecision(*this->cachedRowVector, x, precision, relative); - terminate = this->terminateNow(x, SolverGuarantee::None); - + if (storm::utility::vector::equalModuloPrecision(*this->cachedRowVector, x, precision, relative)) { + status = SolverStatus::Converged; + } + if (this->terminateNow(x, SolverGuarantee::None)) { + status = SolverStatus::TerminatedEarly; + } + // If we did not yet converge, we need to backup the contents of x. - if (!converged) { + if (status != SolverStatus::Converged) { *this->cachedRowVector = x; } @@ -86,15 +89,19 @@ namespace storm { // Increase iteration count so we can abort if convergence is too slow. ++iterations; + + if (storm::utility::resources::isTerminate()) { + status = SolverStatus::Aborted; + } } if (!this->isCachingEnabled()) { clearCache(); } - - this->logIterations(converged, terminate, iterations); - - return converged; + + this->reportStatus(status, iterations); + + return status == SolverStatus::Converged; } template @@ -127,20 +134,22 @@ namespace storm { // Set up additional environment variables. uint_fast64_t iterations = 0; - bool converged = false; - bool terminate = false; + SolverStatus status = SolverStatus::InProgress; this->startMeasureProgress(); - while (!converged && !terminate && iterations < maxIter) { + while (status == SolverStatus::InProgress && iterations < maxIter) { // Compute D^-1 * (b - LU * x) and store result in nextX. jacobiDecomposition->multiplier->multiply(env, *currentX, nullptr, *nextX); storm::utility::vector::subtractVectors(b, *nextX, *nextX); storm::utility::vector::multiplyVectorsPointwise(jacobiDecomposition->DVector, *nextX, *nextX); // Now check if the process already converged within our precision. - converged = storm::utility::vector::equalModuloPrecision(*currentX, *nextX, precision, relative); - terminate = this->terminateNow(*currentX, SolverGuarantee::None); - + if (storm::utility::vector::equalModuloPrecision(*currentX, *nextX, precision, relative)) { + status = SolverStatus::Converged; + } + if (this->terminateNow(*currentX, SolverGuarantee::None)) { + status = SolverStatus::TerminatedEarly; + } // Swap the two pointers as a preparation for the next iteration. std::swap(nextX, currentX); @@ -149,6 +158,10 @@ namespace storm { // Increase iteration count so we can abort if convergence is too slow. ++iterations; + + if (storm::utility::resources::isTerminate()) { + status = SolverStatus::Aborted; + } } // If the last iteration did not write to the original x we have to swap the contents, because the @@ -160,10 +173,10 @@ namespace storm { if (!this->isCachingEnabled()) { clearCache(); } - - this->logIterations(converged, terminate, iterations); - return converged; + this->reportStatus(status, iterations); + + return status == SolverStatus::Converged; } template @@ -257,10 +270,10 @@ namespace storm { walkerChaeData->multiplier->multiply(env, *currentX, nullptr, currentAx); // (3) Perform iterations until convergence. - bool converged = false; + SolverStatus status = SolverStatus::InProgress; uint64_t iterations = 0; this->startMeasureProgress(); - while (!converged && iterations < maxIter) { + while (status == SolverStatus::InProgress && iterations < maxIter) { // Perform one Walker-Chae step. walkerChaeData->matrix.performWalkerChaeStep(*currentX, walkerChaeData->columnSums, walkerChaeData->b, currentAx, *nextX); @@ -268,7 +281,9 @@ namespace storm { walkerChaeData->multiplier->multiply(env, *nextX, nullptr, currentAx); // Check for convergence. - converged = storm::utility::vector::computeSquaredNorm2Difference(currentAx, walkerChaeData->b) <= squaredErrorBound; + if (storm::utility::vector::computeSquaredNorm2Difference(currentAx, walkerChaeData->b) <= squaredErrorBound) { + status = SolverStatus::Converged; + } // Swap the x vectors for the next iteration. std::swap(currentX, nextX); @@ -278,6 +293,10 @@ namespace storm { // Increase iteration count so we can abort if convergence is too slow. ++iterations; + + if (storm::utility::resources::isTerminate()) { + status = SolverStatus::Aborted; + } } // If the last iteration did not write to the original x we have to swap the contents, because the @@ -296,13 +315,9 @@ namespace storm { clearCache(); } - if (converged) { - STORM_LOG_INFO("Iterative solver converged in " << iterations << " iterations."); - } else { - STORM_LOG_WARN("Iterative solver did not converge in " << iterations << " iterations."); - } + this->reportStatus(status, iterations); - return converged; + return status == SolverStatus::Converged; } template @@ -433,9 +448,8 @@ namespace storm { if (!this->multiplier) { this->multiplier = storm::solver::MultiplierFactory().create(env, *A); } - - bool converged = false; - bool terminate = false; + + SolverStatus status = SolverStatus::InProgress; uint64_t iterations = 0; bool doConvergenceCheck = true; bool useDiffs = this->hasRelevantValues() && !env.solver().native().isSymmetricUpdatesSet(); @@ -452,7 +466,7 @@ namespace storm { } uint64_t maxIter = env.solver().native().getMaximalNumberOfIterations(); this->startMeasureProgress(); - while (!converged && !terminate && iterations < maxIter) { + while (status == SolverStatus::InProgress && iterations < maxIter) { // Remember in which directions we took steps in this iteration. bool lowerStep = false; bool upperStep = false; @@ -537,24 +551,36 @@ namespace storm { // precision here. Doing so, we need to take the means of the lower and upper values later to guarantee // the original precision. if (this->hasRelevantValues()) { - converged = storm::utility::vector::equalModuloPrecision(*lowerX, *upperX, this->getRelevantValues(), precision, relative); + if (storm::utility::vector::equalModuloPrecision(*lowerX, *upperX, this->getRelevantValues(), precision, relative)) { + status = SolverStatus::Converged; + } } else { - converged = storm::utility::vector::equalModuloPrecision(*lowerX, *upperX, precision, relative); + if (storm::utility::vector::equalModuloPrecision(*lowerX, *upperX, precision, relative)) { + status = SolverStatus::Converged; + } } if (lowerStep) { - terminate |= this->terminateNow(*lowerX, SolverGuarantee::LessOrEqual); + if (this->terminateNow(*lowerX, SolverGuarantee::LessOrEqual)) { + status = SolverStatus::TerminatedEarly; + } } if (upperStep) { - terminate |= this->terminateNow(*upperX, SolverGuarantee::GreaterOrEqual); + if (this->terminateNow(*upperX, SolverGuarantee::GreaterOrEqual)) { + status = SolverStatus::TerminatedEarly; + } } } // Potentially show progress. this->showProgressIterative(iterations); + // Set up next iteration. ++iterations; doConvergenceCheck = !doConvergenceCheck; + if (storm::utility::resources::isTerminate()) { + status = SolverStatus::Aborted; + } } // We take the means of the lower and upper bound so we guarantee the desired precision. @@ -570,9 +596,9 @@ namespace storm { if (!this->isCachingEnabled()) { clearCache(); } - this->logIterations(converged, terminate, iterations); + this->reportStatus(status, iterations); - return converged; + return status == SolverStatus::Converged; } @@ -603,35 +629,39 @@ namespace storm { relevantValuesPtr = &this->getRelevantValues(); } - bool converged = false; - bool terminate = false; + SolverStatus status = SolverStatus::InProgress; this->startMeasureProgress(); uint64_t iterations = 0; - while (!converged && iterations < env.solver().native().getMaximalNumberOfIterations()) { + while (status == SolverStatus::InProgress && iterations < env.solver().native().getMaximalNumberOfIterations()) { this->soundValueIterationHelper->performIterationStep(b); if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(relevantValuesPtr)) { - converged = true; + status = SolverStatus::Converged; } // Check whether we terminate early. - terminate = this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition()); + if (this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition())) { + status = SolverStatus::TerminatedEarly; + } // Update environment variables. ++iterations; // Potentially show progress. this->showProgressIterative(iterations); + if (storm::utility::resources::isTerminate()) { + status = SolverStatus::Aborted; + } } this->soundValueIterationHelper->setSolutionVector(); - - this->logIterations(converged, terminate, iterations); - + + this->reportStatus(status, iterations); + if (!this->isCachingEnabled()) { clearCache(); } - return converged; + return status == SolverStatus::Converged; } template @@ -961,7 +991,7 @@ namespace storm { // Checked all values at this point. return true; } - + template void NativeLinearEquationSolver::logIterations(bool converged, bool terminate, uint64_t iterations) const { if (converged) { @@ -972,7 +1002,7 @@ namespace storm { STORM_LOG_WARN("Iterative solver did not converge in " << iterations << " iterations."); } } - + template NativeLinearEquationSolverMethod NativeLinearEquationSolver::getMethod(Environment const& env, bool isExactMode) const { // Adjust the method if none was specified and we want exact or sound computations diff --git a/src/storm/solver/StandardGameSolver.cpp b/src/storm/solver/StandardGameSolver.cpp index 812fdb3c6..39cff8567 100644 --- a/src/storm/solver/StandardGameSolver.cpp +++ b/src/storm/solver/StandardGameSolver.cpp @@ -230,7 +230,7 @@ namespace storm { status = updateStatusIfNotConverged(status, x, iterations, maxIter); } while (status == SolverStatus::InProgress); - reportStatus(status, iterations); + this->reportStatus(status, iterations); // If requested, we store the scheduler for retrieval. if (this->isTrackSchedulersSet() && !(providedPlayer1Choices && providedPlayer2Choices)) { @@ -330,7 +330,7 @@ namespace storm { status = updateStatusIfNotConverged(status, *currentX, iterations, maxIter); } - reportStatus(status, iterations); + this->reportStatus(status, iterations); // If we performed an odd number of iterations, we need to swap the x and currentX, because the newest result // is currently stored in currentX, but x is the output vector. @@ -574,18 +574,6 @@ namespace storm { return status; } - template - void StandardGameSolver::reportStatus(SolverStatus status, uint64_t iterations) const { - switch (status) { - case SolverStatus::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break; - case SolverStatus::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break; - case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break; - case SolverStatus::Aborted: STORM_LOG_WARN("Iterative solver was aborted."); break; - default: - STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly."); - } - } - template void StandardGameSolver::clearCache() const { multiplierPlayer2Matrix.reset(); diff --git a/src/storm/solver/StandardGameSolver.h b/src/storm/solver/StandardGameSolver.h index b69ec91c5..ac1196a52 100644 --- a/src/storm/solver/StandardGameSolver.h +++ b/src/storm/solver/StandardGameSolver.h @@ -56,8 +56,7 @@ namespace storm { mutable std::unique_ptr> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const; - void reportStatus(SolverStatus status, uint64_t iterations) const; - + /// The factory used to obtain linear equation solvers. std::unique_ptr> linearEquationSolverFactory;