Browse Source

Moved reportStatus() and updateStatusIfNotConverged() to AbstractEquationSolver

main
Matthias Volk 5 years ago
parent
commit
b745b10b77
  1. 65
      src/storm/solver/AbstractEquationSolver.cpp
  2. 20
      src/storm/solver/AbstractEquationSolver.h
  3. 49
      src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
  4. 2
      src/storm/solver/IterativeMinMaxLinearEquationSolver.h
  5. 112
      src/storm/solver/NativeLinearEquationSolver.cpp
  6. 16
      src/storm/solver/StandardGameSolver.cpp
  7. 1
      src/storm/solver/StandardGameSolver.h

65
src/storm/solver/AbstractEquationSolver.cpp

@ -2,14 +2,15 @@
#include "storm/adapters/RationalNumberAdapter.h" #include "storm/adapters/RationalNumberAdapter.h"
#include "storm/adapters/RationalFunctionAdapter.h" #include "storm/adapters/RationalFunctionAdapter.h"
#include "storm/exceptions/InvalidOperationException.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/UnmetRequirementException.h"
#include "storm/settings/SettingsManager.h" #include "storm/settings/SettingsManager.h"
#include "storm/settings/modules/GeneralSettings.h" #include "storm/settings/modules/GeneralSettings.h"
#include "storm/utility/constants.h" #include "storm/utility/constants.h"
#include "storm/utility/macros.h" #include "storm/utility/macros.h"
#include "storm/exceptions/UnmetRequirementException.h"
#include "storm/exceptions/InvalidOperationException.h"
#include "storm/utility/SignalHandler.h"
namespace storm { namespace storm {
namespace solver { namespace solver {
@ -273,6 +274,62 @@ namespace storm {
} }
} }
template<typename ValueType>
void AbstractEquationSolver<ValueType>::reportStatus(SolverStatus status, boost::optional<uint64_t> const& iterations) const {
if (iterations) {
switch (status) {
case SolverStatus::Converged:
STORM_LOG_TRACE("Iterative solver converged after " << iterations.get() << " iterations.");
break;
case SolverStatus::TerminatedEarly:
STORM_LOG_TRACE("Iterative solver terminated early after " << iterations.get() << " iterations.");
break;
case SolverStatus::MaximalIterationsExceeded:
STORM_LOG_WARN("Iterative solver did not converge after " << iterations.get() << " iterations.");
break;
case SolverStatus::Aborted:
STORM_LOG_WARN("Iterative solver was aborted.");
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}
} else {
switch (status) {
case SolverStatus::Converged:
STORM_LOG_TRACE("Solver converged.");
break;
case SolverStatus::TerminatedEarly:
STORM_LOG_TRACE("Solver terminated early.");
break;
case SolverStatus::MaximalIterationsExceeded:
STORM_LOG_ASSERT(false, "Non-iterative solver should not exceed maximal number of iterations.");
STORM_LOG_WARN("Solver did not converge.");
break;
case SolverStatus::Aborted:
STORM_LOG_WARN("Solver was aborted.");
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Solver terminated unexpectedly.");
}
}
}
template<typename ValueType>
SolverStatus AbstractEquationSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = SolverStatus::MaximalIterationsExceeded;
} else if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
return status;
}
template class AbstractEquationSolver<double>; template class AbstractEquationSolver<double>;
template class AbstractEquationSolver<float>; template class AbstractEquationSolver<float>;

20
src/storm/solver/AbstractEquationSolver.h

@ -6,9 +6,11 @@
#include <iostream> #include <iostream>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include "storm/solver/SolverStatus.h"
#include "storm/solver/TerminationCondition.h" #include "storm/solver/TerminationCondition.h"
#include "storm/utility/ProgressMeasurement.h" #include "storm/utility/ProgressMeasurement.h"
namespace storm { namespace storm {
namespace solver { namespace solver {
@ -191,6 +193,24 @@ namespace storm {
void createUpperBoundsVector(std::unique_ptr<std::vector<ValueType>>& upperBoundsVector, uint64_t length) const; void createUpperBoundsVector(std::unique_ptr<std::vector<ValueType>>& upperBoundsVector, uint64_t length) const;
void createLowerBoundsVector(std::vector<ValueType>& lowerBoundsVector) const; void createLowerBoundsVector(std::vector<ValueType>& lowerBoundsVector) const;
/*!
* Report the current status of the solver.
* @param status Solver status.
* @param iterations Number of iterations (if solver is iterative).
*/
void reportStatus(SolverStatus status, boost::optional<uint64_t> const& iterations = boost::none) const;
/*!
* Update the status of the solver with respect to convergence, early termination, abortion, etc.
* @param status Current status.
* @param x Vector x.
* @param iterations Current number of iterations.
* @param maximalNumberOfIterations Maximal number of iterations.
* @param guarantee Guarantee.
* @return New status.
*/
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
// A termination condition to be used (can be unset). // A termination condition to be used (can be unset).
std::unique_ptr<TerminationCondition<ValueType>> terminationCondition; std::unique_ptr<TerminationCondition<ValueType>> terminationCondition;

49
src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp

@ -199,16 +199,13 @@ namespace storm {
// Update environment variables. // Update environment variables.
++iterations; ++iterations;
status = updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual);
status = this->updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual);
// Potentially show progress. // Potentially show progress.
this->showProgressIterative(iterations); this->showProgressIterative(iterations);
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} while (status == SolverStatus::InProgress); } while (status == SolverStatus::InProgress);
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// If requested, we store the scheduler for retrieval. // If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) { if (this->isTrackSchedulerSet()) {
@ -331,7 +328,7 @@ namespace storm {
// Update environment variables. // Update environment variables.
std::swap(currentX, newX); std::swap(currentX, newX);
++iterations; ++iterations;
status = updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee);
status = this->updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee);
// Potentially show progress. // Potentially show progress.
this->showProgressIterative(iterations); this->showProgressIterative(iterations);
@ -417,7 +414,7 @@ namespace storm {
auto two = storm::utility::convertNumber<ValueType>(2.0); auto two = storm::utility::convertNumber<ValueType>(2.0);
storm::utility::vector::applyPointwise<ValueType, ValueType, ValueType>(*lowerX, *upperX, x, [&two] (ValueType const& a, ValueType const& b) -> ValueType { return (a + b) / two; }); storm::utility::vector::applyPointwise<ValueType, ValueType, ValueType>(*lowerX, *upperX, x, [&two] (ValueType const& a, ValueType const& b) -> ValueType { return (a + b) / two; });
reportStatus(statusIters.first, statusIters.second);
this->reportStatus(statusIters.first, statusIters.second);
// If requested, we store the scheduler for retrieval. // If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) { if (this->isTrackSchedulerSet()) {
@ -503,7 +500,7 @@ namespace storm {
std::swap(x, *currentX); std::swap(x, *currentX);
} }
reportStatus(result.status, result.iterations);
this->reportStatus(result.status, result.iterations);
// If requested, we store the scheduler for retrieval. // If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) { if (this->isTrackSchedulerSet()) {
@ -666,17 +663,17 @@ namespace storm {
++iterations; ++iterations;
doConvergenceCheck = !doConvergenceCheck; doConvergenceCheck = !doConvergenceCheck;
if (lowerStep) { if (lowerStep) {
status = updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual);
status = this->updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual);
} }
if (upperStep) { if (upperStep) {
status = updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual);
status = this->updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual);
} }
// Potentially show progress. // Potentially show progress.
this->showProgressIterative(iterations); this->showProgressIterative(iterations);
} }
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// We take the means of the lower and upper bound so we guarantee the desired precision. // We take the means of the lower and upper bound so we guarantee the desired precision.
ValueType two = storm::utility::convertNumber<ValueType>(2.0); ValueType two = storm::utility::convertNumber<ValueType>(2.0);
@ -761,7 +758,7 @@ namespace storm {
this->A->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get()); this->A->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get());
} }
reportStatus(status, iterations);
this->reportStatus(status, iterations);
if (!this->isCachingEnabled()) { if (!this->isCachingEnabled()) {
clearCache(); clearCache();
@ -1064,7 +1061,7 @@ namespace storm {
status = SolverStatus::MaximalIterationsExceeded; status = SolverStatus::MaximalIterationsExceeded;
} }
reportStatus(status, overallIterations);
this->reportStatus(status, overallIterations);
return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly; return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly;
} }
@ -1104,32 +1101,6 @@ namespace storm {
} }
} }
template<typename ValueType>
SolverStatus IterativeMinMaxLinearEquationSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = SolverStatus::MaximalIterationsExceeded;
} else if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
return status;
}
template<typename ValueType>
void IterativeMinMaxLinearEquationSolver<ValueType>::reportStatus(SolverStatus status, uint64_t iterations) {
switch (status) {
case SolverStatus::Converged: STORM_LOG_TRACE("Iterative solver converged after " << iterations << " iterations."); break;
case SolverStatus::TerminatedEarly: STORM_LOG_TRACE("Iterative solver terminated early after " << iterations << " iterations."); break;
case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
case SolverStatus::Aborted: STORM_LOG_WARN("Iterative solver was aborted."); break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}
}
template<typename ValueType> template<typename ValueType>
void IterativeMinMaxLinearEquationSolver<ValueType>::clearCache() const { void IterativeMinMaxLinearEquationSolver<ValueType>::clearCache() const {
multiplierA.reset(); multiplierA.reset();

2
src/storm/solver/IterativeMinMaxLinearEquationSolver.h

@ -86,8 +86,6 @@ namespace storm {
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector2; // A.rowGroupCount() entries mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector2; // A.rowGroupCount() entries
mutable std::unique_ptr<storm::solver::helper::SoundValueIterationHelper<ValueType>> soundValueIterationHelper; mutable std::unique_ptr<storm::solver::helper::SoundValueIterationHelper<ValueType>> soundValueIterationHelper;
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
static void reportStatus(SolverStatus status, uint64_t iterations);
}; };
} }

112
src/storm/solver/NativeLinearEquationSolver.cpp

@ -65,19 +65,22 @@ namespace storm {
// Set up additional environment variables. // Set up additional environment variables.
uint_fast64_t iterations = 0; uint_fast64_t iterations = 0;
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
this->startMeasureProgress(); this->startMeasureProgress();
while (!converged && !terminate && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
A->performSuccessiveOverRelaxationStep(omega, x, b); A->performSuccessiveOverRelaxationStep(omega, x, b);
// Now check if the process already converged within our precision. // Now check if the process already converged within our precision.
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*this->cachedRowVector, x, precision, relative);
terminate = this->terminateNow(x, SolverGuarantee::None);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*this->cachedRowVector, x, precision, relative)) {
status = SolverStatus::Converged;
}
if (this->terminateNow(x, SolverGuarantee::None)) {
status = SolverStatus::TerminatedEarly;
}
// If we did not yet converge, we need to backup the contents of x. // If we did not yet converge, we need to backup the contents of x.
if (!converged) {
if (status != SolverStatus::Converged) {
*this->cachedRowVector = x; *this->cachedRowVector = x;
} }
@ -86,15 +89,19 @@ namespace storm {
// Increase iteration count so we can abort if convergence is too slow. // Increase iteration count so we can abort if convergence is too slow.
++iterations; ++iterations;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} }
if (!this->isCachingEnabled()) { if (!this->isCachingEnabled()) {
clearCache(); clearCache();
} }
this->logIterations(converged, terminate, iterations);
this->reportStatus(status, iterations);
return converged;
return status == SolverStatus::Converged;
} }
template<typename ValueType> template<typename ValueType>
@ -127,20 +134,22 @@ namespace storm {
// Set up additional environment variables. // Set up additional environment variables.
uint_fast64_t iterations = 0; uint_fast64_t iterations = 0;
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
this->startMeasureProgress(); this->startMeasureProgress();
while (!converged && !terminate && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
// Compute D^-1 * (b - LU * x) and store result in nextX. // Compute D^-1 * (b - LU * x) and store result in nextX.
jacobiDecomposition->multiplier->multiply(env, *currentX, nullptr, *nextX); jacobiDecomposition->multiplier->multiply(env, *currentX, nullptr, *nextX);
storm::utility::vector::subtractVectors(b, *nextX, *nextX); storm::utility::vector::subtractVectors(b, *nextX, *nextX);
storm::utility::vector::multiplyVectorsPointwise(jacobiDecomposition->DVector, *nextX, *nextX); storm::utility::vector::multiplyVectorsPointwise(jacobiDecomposition->DVector, *nextX, *nextX);
// Now check if the process already converged within our precision. // Now check if the process already converged within our precision.
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *nextX, precision, relative);
terminate = this->terminateNow(*currentX, SolverGuarantee::None);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *nextX, precision, relative)) {
status = SolverStatus::Converged;
}
if (this->terminateNow(*currentX, SolverGuarantee::None)) {
status = SolverStatus::TerminatedEarly;
}
// Swap the two pointers as a preparation for the next iteration. // Swap the two pointers as a preparation for the next iteration.
std::swap(nextX, currentX); std::swap(nextX, currentX);
@ -149,6 +158,10 @@ namespace storm {
// Increase iteration count so we can abort if convergence is too slow. // Increase iteration count so we can abort if convergence is too slow.
++iterations; ++iterations;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} }
// If the last iteration did not write to the original x we have to swap the contents, because the // If the last iteration did not write to the original x we have to swap the contents, because the
@ -161,9 +174,9 @@ namespace storm {
clearCache(); clearCache();
} }
this->logIterations(converged, terminate, iterations);
this->reportStatus(status, iterations);
return converged;
return status == SolverStatus::Converged;
} }
template<typename ValueType> template<typename ValueType>
@ -257,10 +270,10 @@ namespace storm {
walkerChaeData->multiplier->multiply(env, *currentX, nullptr, currentAx); walkerChaeData->multiplier->multiply(env, *currentX, nullptr, currentAx);
// (3) Perform iterations until convergence. // (3) Perform iterations until convergence.
bool converged = false;
SolverStatus status = SolverStatus::InProgress;
uint64_t iterations = 0; uint64_t iterations = 0;
this->startMeasureProgress(); this->startMeasureProgress();
while (!converged && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
// Perform one Walker-Chae step. // Perform one Walker-Chae step.
walkerChaeData->matrix.performWalkerChaeStep(*currentX, walkerChaeData->columnSums, walkerChaeData->b, currentAx, *nextX); walkerChaeData->matrix.performWalkerChaeStep(*currentX, walkerChaeData->columnSums, walkerChaeData->b, currentAx, *nextX);
@ -268,7 +281,9 @@ namespace storm {
walkerChaeData->multiplier->multiply(env, *nextX, nullptr, currentAx); walkerChaeData->multiplier->multiply(env, *nextX, nullptr, currentAx);
// Check for convergence. // Check for convergence.
converged = storm::utility::vector::computeSquaredNorm2Difference(currentAx, walkerChaeData->b) <= squaredErrorBound;
if (storm::utility::vector::computeSquaredNorm2Difference(currentAx, walkerChaeData->b) <= squaredErrorBound) {
status = SolverStatus::Converged;
}
// Swap the x vectors for the next iteration. // Swap the x vectors for the next iteration.
std::swap(currentX, nextX); std::swap(currentX, nextX);
@ -278,6 +293,10 @@ namespace storm {
// Increase iteration count so we can abort if convergence is too slow. // Increase iteration count so we can abort if convergence is too slow.
++iterations; ++iterations;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} }
// If the last iteration did not write to the original x we have to swap the contents, because the // If the last iteration did not write to the original x we have to swap the contents, because the
@ -296,13 +315,9 @@ namespace storm {
clearCache(); clearCache();
} }
if (converged) {
STORM_LOG_INFO("Iterative solver converged in " << iterations << " iterations.");
} else {
STORM_LOG_WARN("Iterative solver did not converge in " << iterations << " iterations.");
}
this->reportStatus(status, iterations);
return converged;
return status == SolverStatus::Converged;
} }
template<typename ValueType> template<typename ValueType>
@ -434,8 +449,7 @@ namespace storm {
this->multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, *A); this->multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, *A);
} }
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
uint64_t iterations = 0; uint64_t iterations = 0;
bool doConvergenceCheck = true; bool doConvergenceCheck = true;
bool useDiffs = this->hasRelevantValues() && !env.solver().native().isSymmetricUpdatesSet(); bool useDiffs = this->hasRelevantValues() && !env.solver().native().isSymmetricUpdatesSet();
@ -452,7 +466,7 @@ namespace storm {
} }
uint64_t maxIter = env.solver().native().getMaximalNumberOfIterations(); uint64_t maxIter = env.solver().native().getMaximalNumberOfIterations();
this->startMeasureProgress(); this->startMeasureProgress();
while (!converged && !terminate && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
// Remember in which directions we took steps in this iteration. // Remember in which directions we took steps in this iteration.
bool lowerStep = false; bool lowerStep = false;
bool upperStep = false; bool upperStep = false;
@ -537,24 +551,36 @@ namespace storm {
// precision here. Doing so, we need to take the means of the lower and upper values later to guarantee // precision here. Doing so, we need to take the means of the lower and upper values later to guarantee
// the original precision. // the original precision.
if (this->hasRelevantValues()) { if (this->hasRelevantValues()) {
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, this->getRelevantValues(), precision, relative);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, this->getRelevantValues(), precision, relative)) {
status = SolverStatus::Converged;
}
} else { } else {
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, precision, relative);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, precision, relative)) {
status = SolverStatus::Converged;
}
} }
if (lowerStep) { if (lowerStep) {
terminate |= this->terminateNow(*lowerX, SolverGuarantee::LessOrEqual);
if (this->terminateNow(*lowerX, SolverGuarantee::LessOrEqual)) {
status = SolverStatus::TerminatedEarly;
}
} }
if (upperStep) { if (upperStep) {
terminate |= this->terminateNow(*upperX, SolverGuarantee::GreaterOrEqual);
if (this->terminateNow(*upperX, SolverGuarantee::GreaterOrEqual)) {
status = SolverStatus::TerminatedEarly;
}
} }
} }
// Potentially show progress. // Potentially show progress.
this->showProgressIterative(iterations); this->showProgressIterative(iterations);
// Set up next iteration. // Set up next iteration.
++iterations; ++iterations;
doConvergenceCheck = !doConvergenceCheck; doConvergenceCheck = !doConvergenceCheck;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} }
// We take the means of the lower and upper bound so we guarantee the desired precision. // We take the means of the lower and upper bound so we guarantee the desired precision.
@ -570,9 +596,9 @@ namespace storm {
if (!this->isCachingEnabled()) { if (!this->isCachingEnabled()) {
clearCache(); clearCache();
} }
this->logIterations(converged, terminate, iterations);
this->reportStatus(status, iterations);
return converged;
return status == SolverStatus::Converged;
} }
@ -603,35 +629,39 @@ namespace storm {
relevantValuesPtr = &this->getRelevantValues(); relevantValuesPtr = &this->getRelevantValues();
} }
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
this->startMeasureProgress(); this->startMeasureProgress();
uint64_t iterations = 0; uint64_t iterations = 0;
while (!converged && iterations < env.solver().native().getMaximalNumberOfIterations()) {
while (status == SolverStatus::InProgress && iterations < env.solver().native().getMaximalNumberOfIterations()) {
this->soundValueIterationHelper->performIterationStep(b); this->soundValueIterationHelper->performIterationStep(b);
if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(relevantValuesPtr)) { if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(relevantValuesPtr)) {
converged = true;
status = SolverStatus::Converged;
} }
// Check whether we terminate early. // Check whether we terminate early.
terminate = this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition());
if (this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition())) {
status = SolverStatus::TerminatedEarly;
}
// Update environment variables. // Update environment variables.
++iterations; ++iterations;
// Potentially show progress. // Potentially show progress.
this->showProgressIterative(iterations); this->showProgressIterative(iterations);
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} }
this->soundValueIterationHelper->setSolutionVector(); this->soundValueIterationHelper->setSolutionVector();
this->logIterations(converged, terminate, iterations);
this->reportStatus(status, iterations);
if (!this->isCachingEnabled()) { if (!this->isCachingEnabled()) {
clearCache(); clearCache();
} }
return converged;
return status == SolverStatus::Converged;
} }
template<typename ValueType> template<typename ValueType>

16
src/storm/solver/StandardGameSolver.cpp

@ -230,7 +230,7 @@ namespace storm {
status = updateStatusIfNotConverged(status, x, iterations, maxIter); status = updateStatusIfNotConverged(status, x, iterations, maxIter);
} while (status == SolverStatus::InProgress); } while (status == SolverStatus::InProgress);
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// If requested, we store the scheduler for retrieval. // If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulersSet() && !(providedPlayer1Choices && providedPlayer2Choices)) { if (this->isTrackSchedulersSet() && !(providedPlayer1Choices && providedPlayer2Choices)) {
@ -330,7 +330,7 @@ namespace storm {
status = updateStatusIfNotConverged(status, *currentX, iterations, maxIter); status = updateStatusIfNotConverged(status, *currentX, iterations, maxIter);
} }
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// If we performed an odd number of iterations, we need to swap the x and currentX, because the newest result // If we performed an odd number of iterations, we need to swap the x and currentX, because the newest result
// is currently stored in currentX, but x is the output vector. // is currently stored in currentX, but x is the output vector.
@ -574,18 +574,6 @@ namespace storm {
return status; return status;
} }
template<typename ValueType>
void StandardGameSolver<ValueType>::reportStatus(SolverStatus status, uint64_t iterations) const {
switch (status) {
case SolverStatus::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break;
case SolverStatus::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break;
case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
case SolverStatus::Aborted: STORM_LOG_WARN("Iterative solver was aborted."); break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}
}
template<typename ValueType> template<typename ValueType>
void StandardGameSolver<ValueType>::clearCache() const { void StandardGameSolver<ValueType>::clearCache() const {
multiplierPlayer2Matrix.reset(); multiplierPlayer2Matrix.reset();

1
src/storm/solver/StandardGameSolver.h

@ -56,7 +56,6 @@ namespace storm {
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const; SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
void reportStatus(SolverStatus status, uint64_t iterations) const;
/// The factory used to obtain linear equation solvers. /// The factory used to obtain linear equation solvers.
std::unique_ptr<LinearEquationSolverFactory<ValueType>> linearEquationSolverFactory; std::unique_ptr<LinearEquationSolverFactory<ValueType>> linearEquationSolverFactory;

Loading…
Cancel
Save