diff --git a/src/storm/solver/AbstractEquationSolver.cpp b/src/storm/solver/AbstractEquationSolver.cpp index 031599847..e7583e346 100644 --- a/src/storm/solver/AbstractEquationSolver.cpp +++ b/src/storm/solver/AbstractEquationSolver.cpp @@ -317,9 +317,9 @@ namespace storm { template - SolverStatus AbstractEquationSolver::updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const { + SolverStatus AbstractEquationSolver::updateStatus(SolverStatus status, bool earlyTermination, uint64_t iterations, uint64_t maximalNumberOfIterations) const { if (status != SolverStatus::Converged) { - if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) { + if (earlyTermination) { status = SolverStatus::TerminatedEarly; } else if (iterations >= maximalNumberOfIterations) { status = SolverStatus::MaximalIterationsExceeded; @@ -329,7 +329,13 @@ namespace storm { } return status; } - + + template + SolverStatus AbstractEquationSolver::updateStatus(SolverStatus status, std::vector const& x, SolverGuarantee const& guarantee, uint64_t iterations, uint64_t maximalNumberOfIterations) const { + return this->updateStatus(status, this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee), iterations, maximalNumberOfIterations); + } + + template class AbstractEquationSolver; template class AbstractEquationSolver; diff --git a/src/storm/solver/AbstractEquationSolver.h b/src/storm/solver/AbstractEquationSolver.h index 00fd7db9b..8e303f042 100644 --- a/src/storm/solver/AbstractEquationSolver.h +++ b/src/storm/solver/AbstractEquationSolver.h @@ -203,13 +203,23 @@ namespace storm { /*! * Update the status of the solver with respect to convergence, early termination, abortion, etc. * @param status Current status. - * @param x Vector x. + * @param x Vector for terminatation condition. + * @param guarantee Guarentee for termination condition. * @param iterations Current number of iterations. * @param maximalNumberOfIterations Maximal number of iterations. - * @param guarantee Guarantee. * @return New status. */ - SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const; + SolverStatus updateStatus(SolverStatus status, std::vector const& x, SolverGuarantee const& guarantee, uint64_t iterations, uint64_t maximalNumberOfIterations) const; + + /*! + * Update the status of the solver with respect to convergence, early termination, abortion, etc. + * @param status Current status. + * @param earlyTermination Flag indicating if the solver can be terminated early. + * @param iterations Current number of iterations. + * @param maximalNumberOfIterations Maximal number of iterations. + * @return New status. + */ + SolverStatus updateStatus(SolverStatus status, bool earlyTermination, uint64_t iterations, uint64_t maximalNumberOfIterations) const; // A termination condition to be used (can be unset). std::unique_ptr> terminationCondition; diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index e60c3c6ba..7684a12f3 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -199,7 +199,7 @@ namespace storm { // Update environment variables. ++iterations; - status = this->updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual); + status = this->updateStatus(status, x, dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual, iterations, env.solver().minMax().getMaximalNumberOfIterations()); // Potentially show progress. this->showProgressIterative(iterations); @@ -328,7 +328,7 @@ namespace storm { // Update environment variables. std::swap(currentX, newX); ++iterations; - status = this->updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee); + status = this->updateStatus(status, *currentX, guarantee, iterations, maximalNumberOfIterations); // Potentially show progress. this->showProgressIterative(iterations); @@ -663,10 +663,10 @@ namespace storm { ++iterations; doConvergenceCheck = !doConvergenceCheck; if (lowerStep) { - status = this->updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual); + status = this->updateStatus(status, *lowerX, SolverGuarantee::LessOrEqual, iterations, env.solver().minMax().getMaximalNumberOfIterations()); } if (upperStep) { - status = this->updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual); + status = this->updateStatus(status, *upperX, SolverGuarantee::GreaterOrEqual, iterations, env.solver().minMax().getMaximalNumberOfIterations()); } // Potentially show progress. diff --git a/src/storm/solver/StandardGameSolver.cpp b/src/storm/solver/StandardGameSolver.cpp index 39cff8567..9f004eb75 100644 --- a/src/storm/solver/StandardGameSolver.cpp +++ b/src/storm/solver/StandardGameSolver.cpp @@ -227,7 +227,7 @@ namespace storm { // Update environment variables. ++iterations; - status = updateStatusIfNotConverged(status, x, iterations, maxIter); + status = this->updateStatus(status, x, SolverGuarantee::None, iterations, maxIter); } while (status == SolverStatus::InProgress); this->reportStatus(status, iterations); @@ -327,7 +327,7 @@ namespace storm { // Update environment variables. std::swap(currentX, newX); ++iterations; - status = updateStatusIfNotConverged(status, *currentX, iterations, maxIter); + status = this->updateStatus(status, *currentX, SolverGuarantee::None, iterations, maxIter); } this->reportStatus(status, iterations); @@ -559,21 +559,7 @@ namespace storm { uint64_t StandardGameSolver::getNumberOfPlayer2States() const { return this->player2Matrix.getRowGroupCount(); } - - template - SolverStatus StandardGameSolver::updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const { - if (status != SolverStatus::Converged) { - if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x)) { - status = SolverStatus::TerminatedEarly; - } else if (iterations >= maximalNumberOfIterations) { - status = SolverStatus::MaximalIterationsExceeded; - } else if (storm::utility::resources::isTerminate()) { - status = SolverStatus::Aborted; - } - } - return status; - } - + template void StandardGameSolver::clearCache() const { multiplierPlayer2Matrix.reset(); diff --git a/src/storm/solver/StandardGameSolver.h b/src/storm/solver/StandardGameSolver.h index ac1196a52..381e445dd 100644 --- a/src/storm/solver/StandardGameSolver.h +++ b/src/storm/solver/StandardGameSolver.h @@ -55,8 +55,6 @@ namespace storm { mutable std::unique_ptr> auxiliaryP2RowGroupVector; // player2Matrix.rowGroupCount() entries mutable std::unique_ptr> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries - SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const; - /// The factory used to obtain linear equation solvers. std::unique_ptr> linearEquationSolverFactory;