Browse Source

Moved reportStatus() and updateStatusIfNotConverged() to AbstractEquationSolver

tempestpy_adaptions
Matthias Volk 5 years ago
parent
commit
b745b10b77
  1. 65
      src/storm/solver/AbstractEquationSolver.cpp
  2. 20
      src/storm/solver/AbstractEquationSolver.h
  3. 49
      src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
  4. 2
      src/storm/solver/IterativeMinMaxLinearEquationSolver.h
  5. 130
      src/storm/solver/NativeLinearEquationSolver.cpp
  6. 16
      src/storm/solver/StandardGameSolver.cpp
  7. 3
      src/storm/solver/StandardGameSolver.h

65
src/storm/solver/AbstractEquationSolver.cpp

@ -2,14 +2,15 @@
#include "storm/adapters/RationalNumberAdapter.h"
#include "storm/adapters/RationalFunctionAdapter.h"
#include "storm/exceptions/InvalidOperationException.h"
#include "storm/exceptions/InvalidStateException.h"
#include "storm/exceptions/UnmetRequirementException.h"
#include "storm/settings/SettingsManager.h"
#include "storm/settings/modules/GeneralSettings.h"
#include "storm/utility/constants.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/UnmetRequirementException.h"
#include "storm/exceptions/InvalidOperationException.h"
#include "storm/utility/SignalHandler.h"
namespace storm {
namespace solver {
@ -272,6 +273,62 @@ namespace storm {
this->progressMeasurement->updateProgress(iteration);
}
}
template<typename ValueType>
void AbstractEquationSolver<ValueType>::reportStatus(SolverStatus status, boost::optional<uint64_t> const& iterations) const {
if (iterations) {
switch (status) {
case SolverStatus::Converged:
STORM_LOG_TRACE("Iterative solver converged after " << iterations.get() << " iterations.");
break;
case SolverStatus::TerminatedEarly:
STORM_LOG_TRACE("Iterative solver terminated early after " << iterations.get() << " iterations.");
break;
case SolverStatus::MaximalIterationsExceeded:
STORM_LOG_WARN("Iterative solver did not converge after " << iterations.get() << " iterations.");
break;
case SolverStatus::Aborted:
STORM_LOG_WARN("Iterative solver was aborted.");
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}
} else {
switch (status) {
case SolverStatus::Converged:
STORM_LOG_TRACE("Solver converged.");
break;
case SolverStatus::TerminatedEarly:
STORM_LOG_TRACE("Solver terminated early.");
break;
case SolverStatus::MaximalIterationsExceeded:
STORM_LOG_ASSERT(false, "Non-iterative solver should not exceed maximal number of iterations.");
STORM_LOG_WARN("Solver did not converge.");
break;
case SolverStatus::Aborted:
STORM_LOG_WARN("Solver was aborted.");
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Solver terminated unexpectedly.");
}
}
}
template<typename ValueType>
SolverStatus AbstractEquationSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = SolverStatus::MaximalIterationsExceeded;
} else if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
return status;
}
template class AbstractEquationSolver<double>;
template class AbstractEquationSolver<float>;

20
src/storm/solver/AbstractEquationSolver.h

@ -6,9 +6,11 @@
#include <iostream>
#include <boost/optional.hpp>
#include "storm/solver/SolverStatus.h"
#include "storm/solver/TerminationCondition.h"
#include "storm/utility/ProgressMeasurement.h"
namespace storm {
namespace solver {
@ -191,6 +193,24 @@ namespace storm {
void createUpperBoundsVector(std::unique_ptr<std::vector<ValueType>>& upperBoundsVector, uint64_t length) const;
void createLowerBoundsVector(std::vector<ValueType>& lowerBoundsVector) const;
/*!
* Report the current status of the solver.
* @param status Solver status.
* @param iterations Number of iterations (if solver is iterative).
*/
void reportStatus(SolverStatus status, boost::optional<uint64_t> const& iterations = boost::none) const;
/*!
* Update the status of the solver with respect to convergence, early termination, abortion, etc.
* @param status Current status.
* @param x Vector x.
* @param iterations Current number of iterations.
* @param maximalNumberOfIterations Maximal number of iterations.
* @param guarantee Guarantee.
* @return New status.
*/
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
// A termination condition to be used (can be unset).
std::unique_ptr<TerminationCondition<ValueType>> terminationCondition;

49
src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp

@ -199,16 +199,13 @@ namespace storm {
// Update environment variables.
++iterations;
status = updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual);
status = this->updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual);
// Potentially show progress.
this->showProgressIterative(iterations);
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
} while (status == SolverStatus::InProgress);
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
@ -331,7 +328,7 @@ namespace storm {
// Update environment variables.
std::swap(currentX, newX);
++iterations;
status = updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee);
status = this->updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee);
// Potentially show progress.
this->showProgressIterative(iterations);
@ -417,7 +414,7 @@ namespace storm {
auto two = storm::utility::convertNumber<ValueType>(2.0);
storm::utility::vector::applyPointwise<ValueType, ValueType, ValueType>(*lowerX, *upperX, x, [&two] (ValueType const& a, ValueType const& b) -> ValueType { return (a + b) / two; });
reportStatus(statusIters.first, statusIters.second);
this->reportStatus(statusIters.first, statusIters.second);
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
@ -503,7 +500,7 @@ namespace storm {
std::swap(x, *currentX);
}
reportStatus(result.status, result.iterations);
this->reportStatus(result.status, result.iterations);
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
@ -666,17 +663,17 @@ namespace storm {
++iterations;
doConvergenceCheck = !doConvergenceCheck;
if (lowerStep) {
status = updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual);
status = this->updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual);
}
if (upperStep) {
status = updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual);
status = this->updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual);
}
// Potentially show progress.
this->showProgressIterative(iterations);
}
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// We take the means of the lower and upper bound so we guarantee the desired precision.
ValueType two = storm::utility::convertNumber<ValueType>(2.0);
@ -761,7 +758,7 @@ namespace storm {
this->A->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get());
}
reportStatus(status, iterations);
this->reportStatus(status, iterations);
if (!this->isCachingEnabled()) {
clearCache();
@ -1064,7 +1061,7 @@ namespace storm {
status = SolverStatus::MaximalIterationsExceeded;
}
reportStatus(status, overallIterations);
this->reportStatus(status, overallIterations);
return status == SolverStatus::Converged || status == SolverStatus::TerminatedEarly;
}
@ -1103,32 +1100,6 @@ namespace storm {
*choice = optimalRow - this->A->getRowGroupIndices()[group];
}
}
template<typename ValueType>
SolverStatus IterativeMinMaxLinearEquationSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = SolverStatus::MaximalIterationsExceeded;
} else if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
return status;
}
template<typename ValueType>
void IterativeMinMaxLinearEquationSolver<ValueType>::reportStatus(SolverStatus status, uint64_t iterations) {
switch (status) {
case SolverStatus::Converged: STORM_LOG_TRACE("Iterative solver converged after " << iterations << " iterations."); break;
case SolverStatus::TerminatedEarly: STORM_LOG_TRACE("Iterative solver terminated early after " << iterations << " iterations."); break;
case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
case SolverStatus::Aborted: STORM_LOG_WARN("Iterative solver was aborted."); break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}
}
template<typename ValueType>
void IterativeMinMaxLinearEquationSolver<ValueType>::clearCache() const {

2
src/storm/solver/IterativeMinMaxLinearEquationSolver.h

@ -86,8 +86,6 @@ namespace storm {
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector2; // A.rowGroupCount() entries
mutable std::unique_ptr<storm::solver::helper::SoundValueIterationHelper<ValueType>> soundValueIterationHelper;
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
static void reportStatus(SolverStatus status, uint64_t iterations);
};
}

130
src/storm/solver/NativeLinearEquationSolver.cpp

@ -65,19 +65,22 @@ namespace storm {
// Set up additional environment variables.
uint_fast64_t iterations = 0;
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
this->startMeasureProgress();
while (!converged && !terminate && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
A->performSuccessiveOverRelaxationStep(omega, x, b);
// Now check if the process already converged within our precision.
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*this->cachedRowVector, x, precision, relative);
terminate = this->terminateNow(x, SolverGuarantee::None);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*this->cachedRowVector, x, precision, relative)) {
status = SolverStatus::Converged;
}
if (this->terminateNow(x, SolverGuarantee::None)) {
status = SolverStatus::TerminatedEarly;
}
// If we did not yet converge, we need to backup the contents of x.
if (!converged) {
if (status != SolverStatus::Converged) {
*this->cachedRowVector = x;
}
@ -86,15 +89,19 @@ namespace storm {
// Increase iteration count so we can abort if convergence is too slow.
++iterations;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
if (!this->isCachingEnabled()) {
clearCache();
}
this->logIterations(converged, terminate, iterations);
return converged;
this->reportStatus(status, iterations);
return status == SolverStatus::Converged;
}
template<typename ValueType>
@ -127,20 +134,22 @@ namespace storm {
// Set up additional environment variables.
uint_fast64_t iterations = 0;
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
this->startMeasureProgress();
while (!converged && !terminate && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
// Compute D^-1 * (b - LU * x) and store result in nextX.
jacobiDecomposition->multiplier->multiply(env, *currentX, nullptr, *nextX);
storm::utility::vector::subtractVectors(b, *nextX, *nextX);
storm::utility::vector::multiplyVectorsPointwise(jacobiDecomposition->DVector, *nextX, *nextX);
// Now check if the process already converged within our precision.
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *nextX, precision, relative);
terminate = this->terminateNow(*currentX, SolverGuarantee::None);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *nextX, precision, relative)) {
status = SolverStatus::Converged;
}
if (this->terminateNow(*currentX, SolverGuarantee::None)) {
status = SolverStatus::TerminatedEarly;
}
// Swap the two pointers as a preparation for the next iteration.
std::swap(nextX, currentX);
@ -149,6 +158,10 @@ namespace storm {
// Increase iteration count so we can abort if convergence is too slow.
++iterations;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
// If the last iteration did not write to the original x we have to swap the contents, because the
@ -160,10 +173,10 @@ namespace storm {
if (!this->isCachingEnabled()) {
clearCache();
}
this->logIterations(converged, terminate, iterations);
return converged;
this->reportStatus(status, iterations);
return status == SolverStatus::Converged;
}
template<typename ValueType>
@ -257,10 +270,10 @@ namespace storm {
walkerChaeData->multiplier->multiply(env, *currentX, nullptr, currentAx);
// (3) Perform iterations until convergence.
bool converged = false;
SolverStatus status = SolverStatus::InProgress;
uint64_t iterations = 0;
this->startMeasureProgress();
while (!converged && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
// Perform one Walker-Chae step.
walkerChaeData->matrix.performWalkerChaeStep(*currentX, walkerChaeData->columnSums, walkerChaeData->b, currentAx, *nextX);
@ -268,7 +281,9 @@ namespace storm {
walkerChaeData->multiplier->multiply(env, *nextX, nullptr, currentAx);
// Check for convergence.
converged = storm::utility::vector::computeSquaredNorm2Difference(currentAx, walkerChaeData->b) <= squaredErrorBound;
if (storm::utility::vector::computeSquaredNorm2Difference(currentAx, walkerChaeData->b) <= squaredErrorBound) {
status = SolverStatus::Converged;
}
// Swap the x vectors for the next iteration.
std::swap(currentX, nextX);
@ -278,6 +293,10 @@ namespace storm {
// Increase iteration count so we can abort if convergence is too slow.
++iterations;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
// If the last iteration did not write to the original x we have to swap the contents, because the
@ -296,13 +315,9 @@ namespace storm {
clearCache();
}
if (converged) {
STORM_LOG_INFO("Iterative solver converged in " << iterations << " iterations.");
} else {
STORM_LOG_WARN("Iterative solver did not converge in " << iterations << " iterations.");
}
this->reportStatus(status, iterations);
return converged;
return status == SolverStatus::Converged;
}
template<typename ValueType>
@ -433,9 +448,8 @@ namespace storm {
if (!this->multiplier) {
this->multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, *A);
}
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
uint64_t iterations = 0;
bool doConvergenceCheck = true;
bool useDiffs = this->hasRelevantValues() && !env.solver().native().isSymmetricUpdatesSet();
@ -452,7 +466,7 @@ namespace storm {
}
uint64_t maxIter = env.solver().native().getMaximalNumberOfIterations();
this->startMeasureProgress();
while (!converged && !terminate && iterations < maxIter) {
while (status == SolverStatus::InProgress && iterations < maxIter) {
// Remember in which directions we took steps in this iteration.
bool lowerStep = false;
bool upperStep = false;
@ -537,24 +551,36 @@ namespace storm {
// precision here. Doing so, we need to take the means of the lower and upper values later to guarantee
// the original precision.
if (this->hasRelevantValues()) {
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, this->getRelevantValues(), precision, relative);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, this->getRelevantValues(), precision, relative)) {
status = SolverStatus::Converged;
}
} else {
converged = storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, precision, relative);
if (storm::utility::vector::equalModuloPrecision<ValueType>(*lowerX, *upperX, precision, relative)) {
status = SolverStatus::Converged;
}
}
if (lowerStep) {
terminate |= this->terminateNow(*lowerX, SolverGuarantee::LessOrEqual);
if (this->terminateNow(*lowerX, SolverGuarantee::LessOrEqual)) {
status = SolverStatus::TerminatedEarly;
}
}
if (upperStep) {
terminate |= this->terminateNow(*upperX, SolverGuarantee::GreaterOrEqual);
if (this->terminateNow(*upperX, SolverGuarantee::GreaterOrEqual)) {
status = SolverStatus::TerminatedEarly;
}
}
}
// Potentially show progress.
this->showProgressIterative(iterations);
// Set up next iteration.
++iterations;
doConvergenceCheck = !doConvergenceCheck;
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
// We take the means of the lower and upper bound so we guarantee the desired precision.
@ -570,9 +596,9 @@ namespace storm {
if (!this->isCachingEnabled()) {
clearCache();
}
this->logIterations(converged, terminate, iterations);
this->reportStatus(status, iterations);
return converged;
return status == SolverStatus::Converged;
}
@ -603,35 +629,39 @@ namespace storm {
relevantValuesPtr = &this->getRelevantValues();
}
bool converged = false;
bool terminate = false;
SolverStatus status = SolverStatus::InProgress;
this->startMeasureProgress();
uint64_t iterations = 0;
while (!converged && iterations < env.solver().native().getMaximalNumberOfIterations()) {
while (status == SolverStatus::InProgress && iterations < env.solver().native().getMaximalNumberOfIterations()) {
this->soundValueIterationHelper->performIterationStep(b);
if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(relevantValuesPtr)) {
converged = true;
status = SolverStatus::Converged;
}
// Check whether we terminate early.
terminate = this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition());
if (this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition())) {
status = SolverStatus::TerminatedEarly;
}
// Update environment variables.
++iterations;
// Potentially show progress.
this->showProgressIterative(iterations);
if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
this->soundValueIterationHelper->setSolutionVector();
this->logIterations(converged, terminate, iterations);
this->reportStatus(status, iterations);
if (!this->isCachingEnabled()) {
clearCache();
}
return converged;
return status == SolverStatus::Converged;
}
template<typename ValueType>
@ -961,7 +991,7 @@ namespace storm {
// Checked all values at this point.
return true;
}
template<typename ValueType>
void NativeLinearEquationSolver<ValueType>::logIterations(bool converged, bool terminate, uint64_t iterations) const {
if (converged) {
@ -972,7 +1002,7 @@ namespace storm {
STORM_LOG_WARN("Iterative solver did not converge in " << iterations << " iterations.");
}
}
template<typename ValueType>
NativeLinearEquationSolverMethod NativeLinearEquationSolver<ValueType>::getMethod(Environment const& env, bool isExactMode) const {
// Adjust the method if none was specified and we want exact or sound computations

16
src/storm/solver/StandardGameSolver.cpp

@ -230,7 +230,7 @@ namespace storm {
status = updateStatusIfNotConverged(status, x, iterations, maxIter);
} while (status == SolverStatus::InProgress);
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulersSet() && !(providedPlayer1Choices && providedPlayer2Choices)) {
@ -330,7 +330,7 @@ namespace storm {
status = updateStatusIfNotConverged(status, *currentX, iterations, maxIter);
}
reportStatus(status, iterations);
this->reportStatus(status, iterations);
// If we performed an odd number of iterations, we need to swap the x and currentX, because the newest result
// is currently stored in currentX, but x is the output vector.
@ -574,18 +574,6 @@ namespace storm {
return status;
}
template<typename ValueType>
void StandardGameSolver<ValueType>::reportStatus(SolverStatus status, uint64_t iterations) const {
switch (status) {
case SolverStatus::Converged: STORM_LOG_INFO("Iterative solver converged after " << iterations << " iterations."); break;
case SolverStatus::TerminatedEarly: STORM_LOG_INFO("Iterative solver terminated early after " << iterations << " iterations."); break;
case SolverStatus::MaximalIterationsExceeded: STORM_LOG_WARN("Iterative solver did not converge after " << iterations << " iterations."); break;
case SolverStatus::Aborted: STORM_LOG_WARN("Iterative solver was aborted."); break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Iterative solver terminated unexpectedly.");
}
}
template<typename ValueType>
void StandardGameSolver<ValueType>::clearCache() const {
multiplierPlayer2Matrix.reset();

3
src/storm/solver/StandardGameSolver.h

@ -56,8 +56,7 @@ namespace storm {
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
void reportStatus(SolverStatus status, uint64_t iterations) const;
/// The factory used to obtain linear equation solvers.
std::unique_ptr<LinearEquationSolverFactory<ValueType>> linearEquationSolverFactory;

Loading…
Cancel
Save