From acd5a94162a1dea54903b00e85d92b428c95887e Mon Sep 17 00:00:00 2001 From: Matthias Volk Date: Thu, 5 Mar 2020 15:06:49 +0100 Subject: [PATCH] Use general SolverStatus in StandardGameSolver --- src/storm/solver/StandardGameSolver.cpp | 40 ++++++++++++------------- src/storm/solver/StandardGameSolver.h | 9 ++---- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/storm/solver/StandardGameSolver.cpp b/src/storm/solver/StandardGameSolver.cpp index 8fcc46a80..9f20b568b 100644 --- a/src/storm/solver/StandardGameSolver.cpp +++ b/src/storm/solver/StandardGameSolver.cpp @@ -6,17 +6,15 @@ #include "storm/solver/EliminationLinearEquationSolver.h" #include "storm/environment/solver/GameSolverEnvironment.h" - +#include "storm/exceptions/InvalidEnvironmentException.h" +#include "storm/exceptions/InvalidStateException.h" +#include "storm/exceptions/NotImplementedException.h" #include "storm/settings/SettingsManager.h" #include "storm/settings/modules/GeneralSettings.h" - #include "storm/utility/ConstantsComparator.h" #include "storm/utility/graph.h" #include "storm/utility/vector.h" #include "storm/utility/macros.h" -#include "storm/exceptions/InvalidEnvironmentException.h" -#include "storm/exceptions/InvalidStateException.h" -#include "storm/exceptions/NotImplementedException.h" namespace storm { namespace solver { @@ -187,7 +185,7 @@ namespace storm { } submatrixSolver->setCachingEnabled(true); - Status status = Status::InProgress; + SolverStatus status = SolverStatus::InProgress; uint64_t iterations = 0; do { submatrixSolver->solveEquations(environmentOfSolver, x, subB); @@ -197,7 +195,7 @@ namespace storm { // If the scheduler did not improve, we are done. if (!schedulerImproved) { - status = Status::Converged; + status = SolverStatus::Converged; } else { // Update the solver. getInducedMatrixVector(x, b, *player1Choices, *player2Choices, submatrix, subB); @@ -229,7 +227,7 @@ namespace storm { // Update environment variables. ++iterations; status = updateStatusIfNotConverged(status, x, iterations, maxIter); - } while (status == Status::InProgress); + } while (status == SolverStatus::InProgress); reportStatus(status, iterations); @@ -243,7 +241,7 @@ namespace storm { clearCache(); } - return status == Status::Converged || status == Status::TerminatedEarly; + return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly; } template @@ -314,15 +312,15 @@ namespace storm { // Proceed with the iterations as long as the method did not converge or reach the maximum number of iterations. uint64_t iterations = 0; - Status status = Status::InProgress; - while (status == Status::InProgress) { + SolverStatus status = SolverStatus::InProgress; + while (status == SolverStatus::InProgress) { multiplyAndReduce(env, player1Dir, player2Dir, *currentX, &b, *multiplierPlayer2Matrix, reducedPlayer2Result, *newX, trackSchedulersInValueIteration ? (trackingSchedulersInProvidedStorage ? player1Choices : &this->player1SchedulerChoices.get()) : nullptr, trackSchedulersInValueIteration ? (trackingSchedulersInProvidedStorage ? player2Choices : &this->player2SchedulerChoices.get()) : nullptr); // Determine whether the method converged. if (storm::utility::vector::equalModuloPrecision(*currentX, *newX, precision, relative)) { - status = Status::Converged; + status = SolverStatus::Converged; } // Update environment variables. @@ -354,7 +352,7 @@ namespace storm { clearCache(); } - return (status == Status::Converged || status == Status::TerminatedEarly); + return (status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly); } template @@ -562,23 +560,23 @@ namespace storm { } template - typename StandardGameSolver::Status StandardGameSolver::updateStatusIfNotConverged(Status status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const { - if (status != Status::Converged) { + SolverStatus StandardGameSolver::updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const { + if (status != SolverStatus::Converged) { if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x)) { - status = Status::TerminatedEarly; + status = SolverStatus::TerminatedEarly; } else if (iterations >= maximalNumberOfIterations) { - status = Status::MaximalIterationsExceeded; + status = SolverStatus::MaximalIterationsExceeded; } } return status; } template - void StandardGameSolver::reportStatus(Status status, uint64_t iterations) const { + void StandardGameSolver::reportStatus(SolverStatus status, uint64_t iterations) const { switch (status) { - case Status::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break; - case Status::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break; - case Status::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break; + case SolverStatus::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break; + case SolverStatus::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break; + case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break; default: STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly."); } diff --git a/src/storm/solver/StandardGameSolver.h b/src/storm/solver/StandardGameSolver.h index 7c7519231..b69ec91c5 100644 --- a/src/storm/solver/StandardGameSolver.h +++ b/src/storm/solver/StandardGameSolver.h @@ -3,6 +3,7 @@ #include "storm/solver/LinearEquationSolver.h" #include "storm/solver/Multiplier.h" #include "storm/solver/GameSolver.h" +#include "storm/solver/SolverStatus.h" #include "SolverSelectionOptions.h" namespace storm { @@ -47,10 +48,6 @@ namespace storm { std::vector const& getPlayer1Grouping() const; uint64_t getNumberOfPlayer1States() const; uint64_t getNumberOfPlayer2States() const; - - enum class Status { - Converged, TerminatedEarly, MaximalIterationsExceeded, InProgress - }; // possibly cached data mutable std::unique_ptr> multiplierPlayer2Matrix; @@ -58,8 +55,8 @@ namespace storm { mutable std::unique_ptr> auxiliaryP2RowGroupVector; // player2Matrix.rowGroupCount() entries mutable std::unique_ptr> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries - Status updateStatusIfNotConverged(Status status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const; - void reportStatus(Status status, uint64_t iterations) const; + SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const; + void reportStatus(SolverStatus status, uint64_t iterations) const; /// The factory used to obtain linear equation solvers. std::unique_ptr> linearEquationSolverFactory;