Browse Source

Improved caching for svi

tempestpy_adaptions
TimQu 7 years ago
parent
commit
8b00f8441e
  1. 24
      src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
  2. 2
      src/storm/solver/IterativeMinMaxLinearEquationSolver.h
  3. 21
      src/storm/solver/NativeLinearEquationSolver.cpp
  4. 2
      src/storm/solver/NativeLinearEquationSolver.h
  5. 1
      src/storm/solver/helper/SoundValueIterationHelper.cpp

24
src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp

@ -10,7 +10,6 @@
#include "storm/utility/KwekMehlhorn.h"
#include "storm/utility/NumberTraits.h"
#include "storm/solver/helper/SoundValueIterationHelper.h"
#include "storm/utility/Stopwatch.h"
#include "storm/utility/vector.h"
#include "storm/utility/macros.h"
@ -608,21 +607,23 @@ namespace storm {
template<typename ValueType>
bool IterativeMinMaxLinearEquationSolver<ValueType>::solveEquationsSoundValueIteration(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
// Prepare the solution vectors.
// Prepare the solution vectors and the helper.
assert(x.size() == this->A->getRowGroupCount());
if (!this->auxiliaryRowGroupVector) {
this->auxiliaryRowGroupVector = std::make_unique<std::vector<ValueType>>();
}
// TODO: implement caching for the helper
storm::solver::helper::SoundValueIterationHelper<ValueType> helper(*this->A, x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision()));
if (!this->soundValueIterationHelper) {
this->soundValueIterationHelper = std::make_unique<storm::solver::helper::SoundValueIterationHelper<ValueType>>(*this->A, x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision()));
} else {
this->soundValueIterationHelper = std::make_unique<storm::solver::helper::SoundValueIterationHelper<ValueType>>(std::move(*this->soundValueIterationHelper), x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision()));
}
// Prepare initial bounds for the solution (if given)
if (this->hasLowerBound()) {
helper.setLowerBound(this->getLowerBound(true));
this->soundValueIterationHelper->setLowerBound(this->getLowerBound(true));
}
if (this->hasUpperBound()) {
helper.setUpperBound(this->getUpperBound(true));
this->soundValueIterationHelper->setUpperBound(this->getUpperBound(true));
}
storm::storage::BitVector const* relevantValuesPtr = nullptr;
@ -635,8 +636,8 @@ namespace storm {
uint64_t iterations = 0;
while (status == SolverStatus::InProgress && iterations < env.solver().minMax().getMaximalNumberOfIterations()) {
helper.performIterationStep(dir, b);
if (helper.checkConvergenceUpdateBounds(dir, relevantValuesPtr)) {
this->soundValueIterationHelper->performIterationStep(dir, b);
if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(dir, relevantValuesPtr)) {
status = SolverStatus::Converged;
}
@ -648,7 +649,7 @@ namespace storm {
// Potentially show progress.
this->showProgressIterative(iterations);
}
helper.setSolutionVector();
this->soundValueIterationHelper->setSolutionVector();
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
@ -1000,6 +1001,7 @@ namespace storm {
multiplierA.reset();
auxiliaryRowGroupVector.reset();
auxiliaryRowGroupVector2.reset();
soundValueIterationHelper.reset();
StandardMinMaxLinearEquationSolver<ValueType>::clearCache();
}

2
src/storm/solver/IterativeMinMaxLinearEquationSolver.h

@ -7,6 +7,7 @@
#include "storm/solver/LinearEquationSolver.h"
#include "storm/solver/Multiplier.h"
#include "storm/solver/StandardMinMaxLinearEquationSolver.h"
#include "storm/solver/helper/SoundValueIterationHelper.h"
#include "storm/solver/SolverStatus.h"
@ -80,6 +81,7 @@ namespace storm {
mutable std::unique_ptr<storm::solver::Multiplier<ValueType>> multiplierA;
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector; // A.rowGroupCount() entries
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector2; // A.rowGroupCount() entries
mutable std::unique_ptr<storm::solver::helper::SoundValueIterationHelper<ValueType>> soundValueIterationHelper;
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
static void reportStatus(SolverStatus status, uint64_t iterations);

21
src/storm/solver/NativeLinearEquationSolver.cpp

@ -568,21 +568,23 @@ namespace storm {
template<typename ValueType>
bool NativeLinearEquationSolver<ValueType>::solveEquationsSoundValueIteration(Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
// Prepare the solution vectors.
// Prepare the solution vectors and the helper.
assert(x.size() == this->A->getRowCount());
if (!this->cachedRowVector) {
this->cachedRowVector = std::make_unique<std::vector<ValueType>>();
}
// TODO: implement caching for the helper
storm::solver::helper::SoundValueIterationHelper<ValueType> helper(*this->A, x, *this->cachedRowVector, env.solver().native().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().native().getPrecision()));
if (!this->soundValueIterationHelper) {
this->soundValueIterationHelper = std::make_unique<storm::solver::helper::SoundValueIterationHelper<ValueType>>(*this->A, x, *this->cachedRowVector, env.solver().native().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().native().getPrecision()));
} else {
this->soundValueIterationHelper = std::make_unique<storm::solver::helper::SoundValueIterationHelper<ValueType>>(std::move(*this->soundValueIterationHelper), x, *this->cachedRowVector, env.solver().native().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().native().getPrecision()));
}
// Prepare initial bounds for the solution (if given)
if (this->hasLowerBound()) {
helper.setLowerBound(this->getLowerBound(true));
this->soundValueIterationHelper->setLowerBound(this->getLowerBound(true));
}
if (this->hasUpperBound()) {
helper.setUpperBound(this->getUpperBound(true));
this->soundValueIterationHelper->setUpperBound(this->getUpperBound(true));
}
storm::storage::BitVector const* relevantValuesPtr = nullptr;
@ -596,8 +598,8 @@ namespace storm {
uint64_t iterations = 0;
while (!converged && iterations < env.solver().native().getMaximalNumberOfIterations()) {
helper.performIterationStep(b);
if (helper.checkConvergenceUpdateBounds(relevantValuesPtr)) {
this->soundValueIterationHelper->performIterationStep(b);
if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(relevantValuesPtr)) {
converged = true;
}
@ -610,7 +612,7 @@ namespace storm {
// Potentially show progress.
this->showProgressIterative(iterations);
}
helper.setSolutionVector();
this->soundValueIterationHelper->setSolutionVector();
this->logIterations(converged, terminate, iterations);
@ -973,6 +975,7 @@ namespace storm {
cachedRowVector2.reset();
walkerChaeData.reset();
multiplier.reset();
soundValueIterationHelper.reset();
LinearEquationSolver<ValueType>::clearCache();
}

2
src/storm/solver/NativeLinearEquationSolver.h

@ -8,6 +8,7 @@
#include "storm/solver/SolverSelectionOptions.h"
#include "storm/solver/NativeMultiplier.h"
#include "storm/solver/SolverStatus.h"
#include "storm/solver/helper/SoundValueIterationHelper.h"
#include "storm/utility/NumberTraits.h"
@ -93,6 +94,7 @@ namespace storm {
// cached auxiliary data
mutable std::unique_ptr<std::vector<ValueType>> cachedRowVector2; // A.getRowCount() rows
mutable std::unique_ptr<storm::solver::helper::SoundValueIterationHelper<ValueType>> soundValueIterationHelper;
struct JacobiDecomposition {
JacobiDecomposition(Environment const& env, storm::storage::SparseMatrix<ValueType> const& A);

1
src/storm/solver/helper/SoundValueIterationHelper.cpp

@ -296,7 +296,6 @@ namespace storm {
<< ".");
}
template<typename ValueType>
bool SoundValueIterationHelper<ValueType>::checkConvergencePhase1() {

Loading…
Cancel
Save