Browse Source

General updateStatus function in AbstractEquationSolver

main
Matthias Volk 5 years ago
parent
commit
f50a7a424b
  1. 12
      src/storm/solver/AbstractEquationSolver.cpp
  2. 16
      src/storm/solver/AbstractEquationSolver.h
  3. 8
      src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
  4. 20
      src/storm/solver/StandardGameSolver.cpp
  5. 2
      src/storm/solver/StandardGameSolver.h

12
src/storm/solver/AbstractEquationSolver.cpp

@ -317,9 +317,9 @@ namespace storm {
template<typename ValueType>
SolverStatus AbstractEquationSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const {
SolverStatus AbstractEquationSolver<ValueType>::updateStatus(SolverStatus status, bool earlyTermination, uint64_t iterations, uint64_t maximalNumberOfIterations) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee)) {
if (earlyTermination) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = SolverStatus::MaximalIterationsExceeded;
@ -329,7 +329,13 @@ namespace storm {
}
return status;
}
template<typename ValueType>
SolverStatus AbstractEquationSolver<ValueType>::updateStatus(SolverStatus status, std::vector<ValueType> const& x, SolverGuarantee const& guarantee, uint64_t iterations, uint64_t maximalNumberOfIterations) const {
return this->updateStatus(status, this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x, guarantee), iterations, maximalNumberOfIterations);
}
template class AbstractEquationSolver<double>;
template class AbstractEquationSolver<float>;

16
src/storm/solver/AbstractEquationSolver.h

@ -203,13 +203,23 @@ namespace storm {
/*!
* Update the status of the solver with respect to convergence, early termination, abortion, etc.
* @param status Current status.
* @param x Vector x.
* @param x Vector for terminatation condition.
* @param guarantee Guarentee for termination condition.
* @param iterations Current number of iterations.
* @param maximalNumberOfIterations Maximal number of iterations.
* @param guarantee Guarantee.
* @return New status.
*/
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
SolverStatus updateStatus(SolverStatus status, std::vector<ValueType> const& x, SolverGuarantee const& guarantee, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
/*!
* Update the status of the solver with respect to convergence, early termination, abortion, etc.
* @param status Current status.
* @param earlyTermination Flag indicating if the solver can be terminated early.
* @param iterations Current number of iterations.
* @param maximalNumberOfIterations Maximal number of iterations.
* @return New status.
*/
SolverStatus updateStatus(SolverStatus status, bool earlyTermination, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
// A termination condition to be used (can be unset).
std::unique_ptr<TerminationCondition<ValueType>> terminationCondition;

8
src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp

@ -199,7 +199,7 @@ namespace storm {
// Update environment variables.
++iterations;
status = this->updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual);
status = this->updateStatus(status, x, dir == storm::OptimizationDirection::Minimize ? SolverGuarantee::GreaterOrEqual : SolverGuarantee::LessOrEqual, iterations, env.solver().minMax().getMaximalNumberOfIterations());
// Potentially show progress.
this->showProgressIterative(iterations);
@ -328,7 +328,7 @@ namespace storm {
// Update environment variables.
std::swap(currentX, newX);
++iterations;
status = this->updateStatusIfNotConverged(status, *currentX, iterations, maximalNumberOfIterations, guarantee);
status = this->updateStatus(status, *currentX, guarantee, iterations, maximalNumberOfIterations);
// Potentially show progress.
this->showProgressIterative(iterations);
@ -663,10 +663,10 @@ namespace storm {
++iterations;
doConvergenceCheck = !doConvergenceCheck;
if (lowerStep) {
status = this->updateStatusIfNotConverged(status, *lowerX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::LessOrEqual);
status = this->updateStatus(status, *lowerX, SolverGuarantee::LessOrEqual, iterations, env.solver().minMax().getMaximalNumberOfIterations());
}
if (upperStep) {
status = this->updateStatusIfNotConverged(status, *upperX, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::GreaterOrEqual);
status = this->updateStatus(status, *upperX, SolverGuarantee::GreaterOrEqual, iterations, env.solver().minMax().getMaximalNumberOfIterations());
}
// Potentially show progress.

20
src/storm/solver/StandardGameSolver.cpp

@ -227,7 +227,7 @@ namespace storm {
// Update environment variables.
++iterations;
status = updateStatusIfNotConverged(status, x, iterations, maxIter);
status = this->updateStatus(status, x, SolverGuarantee::None, iterations, maxIter);
} while (status == SolverStatus::InProgress);
this->reportStatus(status, iterations);
@ -327,7 +327,7 @@ namespace storm {
// Update environment variables.
std::swap(currentX, newX);
++iterations;
status = updateStatusIfNotConverged(status, *currentX, iterations, maxIter);
status = this->updateStatus(status, *currentX, SolverGuarantee::None, iterations, maxIter);
}
this->reportStatus(status, iterations);
@ -559,21 +559,7 @@ namespace storm {
uint64_t StandardGameSolver<ValueType>::getNumberOfPlayer2States() const {
return this->player2Matrix.getRowGroupCount();
}
template<typename ValueType>
SolverStatus StandardGameSolver<ValueType>::updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const {
if (status != SolverStatus::Converged) {
if (this->hasCustomTerminationCondition() && this->getTerminationCondition().terminateNow(x)) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= maximalNumberOfIterations) {
status = SolverStatus::MaximalIterationsExceeded;
} else if (storm::utility::resources::isTerminate()) {
status = SolverStatus::Aborted;
}
}
return status;
}
template<typename ValueType>
void StandardGameSolver<ValueType>::clearCache() const {
multiplierPlayer2Matrix.reset();

2
src/storm/solver/StandardGameSolver.h

@ -55,8 +55,6 @@ namespace storm {
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP2RowGroupVector; // player2Matrix.rowGroupCount() entries
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations) const;
/// The factory used to obtain linear equation solvers.
std::unique_ptr<LinearEquationSolverFactory<ValueType>> linearEquationSolverFactory;

Loading…
Cancel
Save