Browse Source

Use general SolverStatus in StandardGameSolver

main
Matthias Volk 5 years ago
parent
commit
acd5a94162
  1. 40
      src/storm/solver/StandardGameSolver.cpp
  2. 9
      src/storm/solver/StandardGameSolver.h

40
src/storm/solver/StandardGameSolver.cpp

@ -6,17 +6,15 @@
#include "storm/solver/EliminationLinearEquationSolver.h"
#include "storm/environment/solver/GameSolverEnvironment.h"
#include "storm/exceptions/InvalidEnvironmentException.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/NotImplementedException.h"
#include "storm/settings/SettingsManager.h"
#include "storm/settings/modules/GeneralSettings.h"
#include "storm/utility/ConstantsComparator.h"
#include "storm/utility/graph.h"
#include "storm/utility/vector.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/InvalidEnvironmentException.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/NotImplementedException.h"
namespace storm {
namespace solver {
@ -187,7 +185,7 @@ namespace storm {
}
submatrixSolver->setCachingEnabled(true);
Status status = Status::InProgress;
SolverStatus status = SolverStatus::InProgress;
uint64_t iterations = 0;
do {
submatrixSolver->solveEquations(environmentOfSolver, x, subB);
@ -197,7 +195,7 @@ namespace storm {
// If the scheduler did not improve, we are done.
if (!schedulerImproved) {
status = Status::Converged;
status = SolverStatus::Converged;
} else {
// Update the solver.
getInducedMatrixVector(x, b, *player1Choices, *player2Choices, submatrix, subB);
@ -229,7 +227,7 @@ namespace storm {
// Update environment variables.
++iterations;
status = updateStatusIfNotConverged(status, x, iterations, maxIter);
} while (status == Status::InProgress);
} while (status == SolverStatus::InProgress);
reportStatus(status, iterations);
@ -243,7 +241,7 @@ namespace storm {
clearCache();
}
return status == Status::Converged || status == Status::TerminatedEarly;
return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly;
}
template<typename ValueType>
@ -314,15 +312,15 @@ namespace storm {
// Proceed with the iterations as long as the method did not converge or reach the maximum number of iterations.
uint64_t iterations = 0;
Status status = Status::InProgress;
while (status == Status::InProgress) {
SolverStatus status = SolverStatus::InProgress;
while (status == SolverStatus::InProgress) {
multiplyAndReduce(env, player1Dir, player2Dir, *currentX, &b, *multiplierPlayer2Matrix, reducedPlayer2Result, *newX,
trackSchedulersInValueIteration ? (trackingSchedulersInProvidedStorage ? player1Choices : &this->player1SchedulerChoices.get()) : nullptr,
trackSchedulersInValueIteration ? (trackingSchedulersInProvidedStorage ? player2Choices : &this->player2SchedulerChoices.get()) : nullptr);
// Determine whether the method converged.
if (storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *newX, precision, relative)) {
status = Status::Converged;
status = SolverStatus::Converged;
}
// Update environment variables.
@ -354,7 +352,7 @@ namespace storm {
clearCache();
}
return (status == Status::Converged || status == Status::TerminatedEarly);
return (status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly);
}
template<typename ValueType>
@ -562,23 +560,23 @@ namespace storm {
}
template<typename ValueType>
typename StandardGameSolver<ValueType>::Status StandardGameSolver<ValueType>::updateStatusIfNotConverged(Status status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const {
if (status != Status::Converged) {
SolverStatus StandardGameSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x)) {
status = Status::TerminatedEarly;
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = Status::MaximalIterationsExceeded;
status = SolverStatus::MaximalIterationsExceeded;
}
}
return status;
}
template<typename ValueType>
void StandardGameSolver<ValueType>::reportStatus(Status status, uint64_t iterations) const {
void StandardGameSolver<ValueType>::reportStatus(SolverStatus status, uint64_t iterations) const {
switch (status) {
case Status::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break;
case Status::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break;
case Status::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
case SolverStatus::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break;
case SolverStatus::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break;
case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}

9
src/storm/solver/StandardGameSolver.h

@ -3,6 +3,7 @@
#include "storm/solver/LinearEquationSolver.h"
#include "storm/solver/Multiplier.h"
#include "storm/solver/GameSolver.h"
#include "storm/solver/SolverStatus.h"
#include "SolverSelectionOptions.h"
namespace storm {
@ -47,10 +48,6 @@ namespace storm {
std::vector<uint64_t> const& getPlayer1Grouping() const;
uint64_t getNumberOfPlayer1States() const;
uint64_t getNumberOfPlayer2States() const;
enum class Status {
Converged, TerminatedEarly, MaximalIterationsExceeded, InProgress
};
// possibly cached data
mutable std::unique_ptr<storm::solver::Multiplier<ValueType>> multiplierPlayer2Matrix;
@ -58,8 +55,8 @@ namespace storm {
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP2RowGroupVector; // player2Matrix.rowGroupCount() entries
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries
Status updateStatusIfNotConverged(Status status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
void reportStatus(Status status, uint64_t iterations) const;
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
void reportStatus(SolverStatus status, uint64_t iterations) const;
/// The factory used to obtain linear equation solvers.
std::unique_ptr<LinearEquationSolverFactory<ValueType>> linearEquationSolverFactory;

Loading…
Cancel
Save