Browse Source

Using multiplier in IterativeMinMaxSolvers

tempestpy_adaptions
TimQu 7 years ago
parent
commit
b7bac59ae0
  1. 150
      src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
  2. 21
      src/storm/solver/IterativeMinMaxLinearEquationSolver.h

150
src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp

@ -22,17 +22,17 @@ namespace storm {
namespace solver {
template<typename ValueType>
IterativeMinMaxLinearEquationSolver<ValueType>::IterativeMinMaxLinearEquationSolver(std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver<ValueType>(std::move(linearEquationSolverFactory)) {
IterativeMinMaxLinearEquationSolver<ValueType>::IterativeMinMaxLinearEquationSolver(std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : linearEquationSolverFactory(std::move(linearEquationSolverFactory)) {
// Intentionally left empty
}
template<typename ValueType>
IterativeMinMaxLinearEquationSolver<ValueType>::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver<ValueType>(A, std::move(linearEquationSolverFactory)) {
IterativeMinMaxLinearEquationSolver<ValueType>::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType> const& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver<ValueType>(A), linearEquationSolverFactory(std::move(linearEquationSolverFactory)) {
// Intentionally left empty.
}
template<typename ValueType>
IterativeMinMaxLinearEquationSolver<ValueType>::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver<ValueType>(std::move(A), std::move(linearEquationSolverFactory)) {
IterativeMinMaxLinearEquationSolver<ValueType>::IterativeMinMaxLinearEquationSolver(storm::storage::SparseMatrix<ValueType>&& A, std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolver<ValueType>(std::move(A)), linearEquationSolverFactory(std::move(linearEquationSolverFactory)) {
// Intentionally left empty.
}
@ -221,12 +221,11 @@ namespace storm {
MinMaxLinearEquationSolverRequirements IterativeMinMaxLinearEquationSolver<ValueType>::getRequirements(Environment const& env, boost::optional<storm::solver::OptimizationDirection> const& direction, bool const& hasInitialScheduler) const {
auto method = getMethod(env, storm::NumberTraits<ValueType>::IsExact);
// Start by getting the requirements of the linear equation solver.
LinearEquationSolverTask linEqTask = LinearEquationSolverTask::Unspecified;
if ((method == MinMaxMethod::ValueIteration && !this->hasInitialScheduler() && !hasInitialScheduler) || method == MinMaxMethod::RationalSearch || method == MinMaxMethod::SoundValueIteration || method == MinMaxMethod::IntervalIteration) {
linEqTask = LinearEquationSolverTask::Multiply;
}
MinMaxLinearEquationSolverRequirements requirements(this->linearEquationSolverFactory->getRequirements(env, linEqTask));
// Check whether a linear equation solver is needed and potentially start with its requirements
bool needsLinEqSolver = false;
needsLinEqSolver |= method == MinMaxMethod::PolicyIteration;
needsLinEqSolver |= method == MinMaxMethod::ValueIteration && (this->hasInitialScheduler() || hasInitialScheduler);
MinMaxLinearEquationSolverRequirements requirements = needsLinEqSolver ? MinMaxLinearEquationSolverRequirements(this->linearEquationSolverFactory->getRequirements(env)) : MinMaxLinearEquationSolverRequirements();
if (method == MinMaxMethod::ValueIteration) {
if (!this->hasUniqueSolution()) { // Traditional value iteration has no requirements if the solution is unique.
@ -275,15 +274,15 @@ namespace storm {
}
template<typename ValueType>
typename IterativeMinMaxLinearEquationSolver<ValueType>::ValueIterationResult IterativeMinMaxLinearEquationSolver<ValueType>::performValueIteration(OptimizationDirection dir, std::vector<ValueType>*& currentX, std::vector<ValueType>*& newX, std::vector<ValueType> const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const {
typename IterativeMinMaxLinearEquationSolver<ValueType>::ValueIterationResult IterativeMinMaxLinearEquationSolver<ValueType>::performValueIteration(Environment const& env, OptimizationDirection dir, std::vector<ValueType>*& currentX, std::vector<ValueType>*& newX, std::vector<ValueType> const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const {
STORM_LOG_ASSERT(currentX != newX, "Vectors must not be aliased.");
// Get handle to linear equation solver.
storm::solver::LinearEquationSolver<ValueType> const& linearEquationSolver = *this->linEqSolverA;
// Get handle to multiplier.
storm::solver::Multiplier<ValueType> const& multiplier = *this->multiplierA;
// Allow aliased multiplications.
bool useGaussSeidelMultiplication = linearEquationSolver.supportsGaussSeidelMultiplication() && multiplicationStyle == storm::solver::MultiplicationStyle::GaussSeidel;
bool useGaussSeidelMultiplication = multiplicationStyle == storm::solver::MultiplicationStyle::GaussSeidel;
// Proceed with the iterations as long as the method did not converge or reach the maximum number of iterations.
uint64_t iterations = currentIterations;
@ -296,9 +295,9 @@ namespace storm {
if (useGaussSeidelMultiplication) {
// Copy over the current vector so we can modify it in-place.
*newX = *currentX;
linearEquationSolver.multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *newX, &b);
multiplier.multiplyAndReduceGaussSeidel(env, dir, *newX, &b);
} else {
linearEquationSolver.multiplyAndReduce(dir, this->A->getRowGroupIndices(), *currentX, &b, *newX);
multiplier.multiplyAndReduce(env, dir, *currentX, &b, *newX);
}
// Determine whether the method converged.
@ -325,9 +324,8 @@ namespace storm {
template<typename ValueType>
bool IterativeMinMaxLinearEquationSolver<ValueType>::solveEquationsValueIteration(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
if (!this->linEqSolverA) {
this->createLinearEquationSolver(env);
this->linEqSolverA->setCachingEnabled(true);
if (!this->multiplierA) {
this->multiplierA = storm::solver::MultiplierFactory<ValueType>().create(env, *this->A);
}
if (!auxiliaryRowGroupVector) {
@ -387,7 +385,7 @@ namespace storm {
std::vector<ValueType>* currentX = &x;
this->startMeasureProgress();
ValueIterationResult result = performValueIteration(dir, currentX, newX, b, storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision()), env.solver().minMax().getRelativeTerminationCriterion(), guarantee, 0, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle());
ValueIterationResult result = performValueIteration(env, dir, currentX, newX, b, storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision()), env.solver().minMax().getRelativeTerminationCriterion(), guarantee, 0, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle());
// Swap the result into the output x.
if (currentX == auxiliaryRowGroupVector.get()) {
@ -399,7 +397,7 @@ namespace storm {
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
this->schedulerChoices = std::vector<uint_fast64_t>(this->A->getRowGroupCount());
this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *auxiliaryRowGroupVector.get(), &this->schedulerChoices.get());
this->multiplierA->multiplyAndReduce(env, dir, x, &b, *auxiliaryRowGroupVector.get(), &this->schedulerChoices.get());
}
if (!this->isCachingEnabled()) {
@ -443,9 +441,8 @@ namespace storm {
bool IterativeMinMaxLinearEquationSolver<ValueType>::solveEquationsIntervalIteration(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
STORM_LOG_THROW(this->hasUpperBound(), storm::exceptions::UnmetRequirementException, "Solver requires upper bound, but none was given.");
if (!this->linEqSolverA) {
this->createLinearEquationSolver(env);
this->linEqSolverA->setCachingEnabled(true);
if (!this->multiplierA) {
this->multiplierA = storm::solver::MultiplierFactory<ValueType>().create(env, *this->A);
}
if (!auxiliaryRowGroupVector) {
@ -453,7 +450,7 @@ namespace storm {
}
// Allow aliased multiplications.
bool useGaussSeidelMultiplication = this->linEqSolverA->supportsGaussSeidelMultiplication() && env.solver().minMax().getMultiplicationStyle() == storm::solver::MultiplicationStyle::GaussSeidel;
bool useGaussSeidelMultiplication = env.solver().minMax().getMultiplicationStyle() == storm::solver::MultiplicationStyle::GaussSeidel;
std::vector<ValueType>* lowerX = &x;
this->createLowerBoundsVector(*lowerX);
@ -497,22 +494,22 @@ namespace storm {
if (useDiffs) {
preserveOldRelevantValues(*lowerX, this->getRelevantValues(), oldValues);
}
this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *lowerX, &b);
this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *lowerX, &b);
if (useDiffs) {
maxLowerDiff = computeMaxAbsDiff(*lowerX, this->getRelevantValues(), oldValues);
preserveOldRelevantValues(*upperX, this->getRelevantValues(), oldValues);
}
this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *upperX, &b);
this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *upperX, &b);
if (useDiffs) {
maxUpperDiff = computeMaxAbsDiff(*upperX, this->getRelevantValues(), oldValues);
}
} else {
this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *lowerX, &b, *tmp);
this->multiplierA->multiplyAndReduce(env, dir, *lowerX, &b, *tmp);
if (useDiffs) {
maxLowerDiff = computeMaxAbsDiff(*lowerX, *tmp, this->getRelevantValues());
}
std::swap(lowerX, tmp);
this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *upperX, &b, *tmp);
this->multiplierA->multiplyAndReduce(env, dir, *upperX, &b, *tmp);
if (useDiffs) {
maxUpperDiff = computeMaxAbsDiff(*upperX, *tmp, this->getRelevantValues());
}
@ -525,7 +522,7 @@ namespace storm {
if (useDiffs) {
preserveOldRelevantValues(*lowerX, this->getRelevantValues(), oldValues);
}
this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *lowerX, &b);
this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *lowerX, &b);
if (useDiffs) {
maxLowerDiff = computeMaxAbsDiff(*lowerX, this->getRelevantValues(), oldValues);
}
@ -534,7 +531,7 @@ namespace storm {
if (useDiffs) {
preserveOldRelevantValues(*upperX, this->getRelevantValues(), oldValues);
}
this->linEqSolverA->multiplyAndReduceGaussSeidel(dir, this->A->getRowGroupIndices(), *upperX, &b);
this->multiplierA->multiplyAndReduceGaussSeidel(env, dir, *upperX, &b);
if (useDiffs) {
maxUpperDiff = computeMaxAbsDiff(*upperX, this->getRelevantValues(), oldValues);
}
@ -542,14 +539,14 @@ namespace storm {
}
} else {
if (maxLowerDiff >= maxUpperDiff) {
this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *lowerX, &b, *tmp);
this->multiplierA->multiplyAndReduce(env, dir, *lowerX, &b, *tmp);
if (useDiffs) {
maxLowerDiff = computeMaxAbsDiff(*lowerX, *tmp, this->getRelevantValues());
}
std::swap(tmp, lowerX);
lowerStep = true;
} else {
this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), *upperX, &b, *tmp);
this->multiplierA->multiplyAndReduce(env, dir, *upperX, &b, *tmp);
if (useDiffs) {
maxUpperDiff = computeMaxAbsDiff(*upperX, *tmp, this->getRelevantValues());
}
@ -604,7 +601,7 @@ namespace storm {
// If requested, we store the scheduler for retrieval.
if (this->isTrackSchedulerSet()) {
this->schedulerChoices = std::vector<uint_fast64_t>(this->A->getRowGroupCount());
this->linEqSolverA->multiplyAndReduce(dir, this->A->getRowGroupIndices(), x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get());
this->multiplierA->multiplyAndReduce(env, dir, x, &b, *this->auxiliaryRowGroupVector, &this->schedulerChoices.get());
}
if (!this->isCachingEnabled()) {
@ -669,19 +666,24 @@ namespace storm {
return maximize(dir) ? minIndex : maxIndex;
}
void multiplyRow(uint64_t const& row, storm::storage::SparseMatrix<ValueType> const& A, ValueType const& bi, ValueType& xi, ValueType& yi) {
void multiplyRow(uint64_t const& row, storm::storage::SparseMatrix<ValueType> const& A, storm::solver::Multiplier<ValueType> const& multiplier, ValueType const& bi, ValueType& xi, ValueType& yi) {
xi = multiplier.multiplyRow(row, x, bi);
yi = multiplier.multiplyRow(row, y, storm::utility::zero<ValueType>());
/*
xi = bi;
yi = storm::utility::zero<ValueType>();
for (auto const& entry : A.getRow(row)) {
xi += entry.getValue() * x[entry.getColumn()];
yi += entry.getValue() * y[entry.getColumn()];
}
*/
}
template<OptimizationDirection dir>
void performIterationStep(storm::storage::SparseMatrix<ValueType> const& A, std::vector<ValueType> const& b) {
void performIterationStep(storm::storage::SparseMatrix<ValueType> const& A, storm::solver::Multiplier<ValueType> const& multiplier, std::vector<ValueType> const& b) {
if (!decisionValueBlocks) {
performIterationStepUpdateDecisionValue<dir>(A, b);
performIterationStepUpdateDecisionValue<dir>(A, multiplier, b);
} else {
assert(decisionValue == getPrimaryBound<dir>());
auto xIt = x.rbegin();
@ -693,7 +695,7 @@ namespace storm {
// Perform the iteration for the first row in the group
uint64_t row = *groupStartIt;
ValueType xBest, yBest;
multiplyRow(row, A, b[row], xBest, yBest);
multiplyRow(row, A, multiplier, b[row], xBest, yBest);
++row;
// Only do more work if there are still rows in this row group
if (row != groupEnd) {
@ -701,7 +703,7 @@ namespace storm {
ValueType bestValue = xBest + yBest * getPrimaryBound<dir>();
for (;row < groupEnd; ++row) {
// Get the multiplication results
multiplyRow(row, A, b[row], xi, yi);
multiplyRow(row, A, multiplier, b[row], xi, yi);
ValueType currentValue = xi + yi * getPrimaryBound<dir>();
// Check if the current row is better then the previously found one
if (better<dir>(currentValue, bestValue)) {
@ -722,7 +724,7 @@ namespace storm {
}
template<OptimizationDirection dir>
void performIterationStepUpdateDecisionValue(storm::storage::SparseMatrix<ValueType> const& A, std::vector<ValueType> const& b) {
void performIterationStepUpdateDecisionValue(storm::storage::SparseMatrix<ValueType> const& A, storm::solver::Multiplier<ValueType> const& multiplier, std::vector<ValueType> const& b) {
auto xIt = x.rbegin();
auto yIt = y.rbegin();
auto groupStartIt = A.getRowGroupIndices().rbegin();
@ -732,7 +734,7 @@ namespace storm {
// Perform the iteration for the first row in the group
uint64_t row = *groupStartIt;
ValueType xBest, yBest;
multiplyRow(row, A, b[row], xBest, yBest);
multiplyRow(row, A, multiplier, b[row], xBest, yBest);
++row;
// Only do more work if there are still rows in this row group
if (row != groupEnd) {
@ -742,7 +744,7 @@ namespace storm {
ValueType bestValue = xBest + yBest * getPrimaryBound<dir>();
for (;row < groupEnd; ++row) {
// Get the multiplication results
multiplyRow(row, A, b[row], xi, yi);
multiplyRow(row, A, multiplier, b[row], xi, yi);
ValueType currentValue = xi + yi * getPrimaryBound<dir>();
// Check if the current row is better then the previously found one
if (better<dir>(currentValue, bestValue)) {
@ -769,7 +771,7 @@ namespace storm {
}
} else {
for (;row < groupEnd; ++row) {
multiplyRow(row, A, b[row], xi, yi);
multiplyRow(row, A, multiplier, b[row], xi, yi);
// Update the best choice
if (yi > yBest || (yi == yBest && better<dir>(xi, xBest))) {
xTmp[xyTmpIndex] = std::move(xBest);
@ -978,6 +980,10 @@ namespace storm {
this->auxiliaryRowGroupVector = std::make_unique<std::vector<ValueType>>();
}
if (!this->multiplierA) {
this->multiplierA = storm::solver::MultiplierFactory<ValueType>().create(env, *this->A);
}
SoundValueIterationHelper<ValueType> helper(x, *this->auxiliaryRowGroupVector, env.solver().minMax().getRelativeTerminationCriterion(), storm::utility::convertNumber<ValueType>(env.solver().minMax().getPrecision()), this->A->getSizeOfLargestRowGroup());
// Prepare initial bounds for the solution (if given)
@ -999,13 +1005,13 @@ namespace storm {
while (status == SolverStatus::InProgress && iterations < env.solver().minMax().getMaximalNumberOfIterations()) {
if (minimize(dir)) {
helper.template performIterationStep<OptimizationDirection::Minimize>(*this->A, b);
helper.template performIterationStep<OptimizationDirection::Minimize>(*this->A, *this->multiplierA, b);
if (helper.template checkConvergenceUpdateBounds<OptimizationDirection::Minimize>(relevantValuesPtr)) {
status = SolverStatus::Converged;
}
} else {
assert(maximize(dir));
helper.template performIterationStep<OptimizationDirection::Maximize>(*this->A, b);
helper.template performIterationStep<OptimizationDirection::Maximize>(*this->A, *this->multiplierA, b);
if (helper.template checkConvergenceUpdateBounds<OptimizationDirection::Maximize>(relevantValuesPtr)) {
status = SolverStatus::Converged;
}
@ -1085,11 +1091,6 @@ namespace storm {
return false;
}
template<typename ValueType>
void IterativeMinMaxLinearEquationSolver<ValueType>::createLinearEquationSolver(Environment const& env) const {
this->linEqSolverA = this->linearEquationSolverFactory->create(env, *this->A, LinearEquationSolverTask::Multiply);
}
template<typename ValueType>
template<typename ImpreciseType>
typename std::enable_if<std::is_same<ValueType, ImpreciseType>::value && !NumberTraits<ValueType>::IsExact, bool>::type IterativeMinMaxLinearEquationSolver<ValueType>::solveEquationsRationalSearchHelper(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
@ -1100,9 +1101,8 @@ namespace storm {
std::vector<storm::RationalNumber> rationalX(x.size());
std::vector<storm::RationalNumber> rationalB = storm::utility::vector::convertNumericVector<storm::RationalNumber>(b);
if (!this->linEqSolverA) {
this->createLinearEquationSolver(env);
this->linEqSolverA->setCachingEnabled(true);
if (!this->multiplierA) {
this->multiplierA = storm::solver::MultiplierFactory<ValueType>().create(env, *this->A);
}
if (!auxiliaryRowGroupVector) {
@ -1130,9 +1130,8 @@ namespace storm {
typename std::enable_if<std::is_same<ValueType, ImpreciseType>::value && NumberTraits<ValueType>::IsExact, bool>::type IterativeMinMaxLinearEquationSolver<ValueType>::solveEquationsRationalSearchHelper(Environment const& env, OptimizationDirection dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
// Version for when the overall value type is exact and the same type is to be used for the imprecise part.
if (!this->linEqSolverA) {
this->createLinearEquationSolver(env);
this->linEqSolverA->setCachingEnabled(true);
if (!this->multiplierA) {
this->multiplierA = storm::solver::MultiplierFactory<ValueType>().create(env, *this->A);
}
if (!auxiliaryRowGroupVector) {
@ -1182,13 +1181,14 @@ namespace storm {
// Create imprecise solver from the imprecise data.
IterativeMinMaxLinearEquationSolver<ImpreciseType> impreciseSolver(std::make_unique<storm::solver::GeneralLinearEquationSolverFactory<ImpreciseType>>());
impreciseSolver.setMatrix(impreciseA);
impreciseSolver.createLinearEquationSolver(env);
impreciseSolver.setCachingEnabled(true);
impreciseSolver.multiplierA = storm::solver::MultiplierFactory<ImpreciseType>().create(env, impreciseA);
bool converged = false;
try {
// Forward the call to the core rational search routine.
converged = solveEquationsRationalSearchHelper<ValueType, ImpreciseType>(env, dir, impreciseSolver, *this->A, x, b, impreciseA, impreciseX, impreciseB, impreciseTmpX);
impreciseSolver.clearCache();
} catch (storm::exceptions::PrecisionExceededException const& e) {
STORM_LOG_WARN("Precision of value type was exceeded, trying to recover by switching to rational arithmetic.");
@ -1208,9 +1208,8 @@ namespace storm {
impreciseB = std::vector<ImpreciseType>();
impreciseA = storm::storage::SparseMatrix<ImpreciseType>();
if (!this->linEqSolverA) {
createLinearEquationSolver(env);
this->linEqSolverA->setCachingEnabled(true);
if (!this->multiplierA) {
this->multiplierA = storm::solver::MultiplierFactory<ValueType>().create(env, *this->A);
}
// Forward the call to the core rational search routine, but now with our value type as the imprecise value type.
@ -1270,7 +1269,7 @@ namespace storm {
impreciseSolver.startMeasureProgress();
while (status == SolverStatus::InProgress && overallIterations < env.solver().minMax().getMaximalNumberOfIterations()) {
// Perform value iteration with the current precision.
typename IterativeMinMaxLinearEquationSolver<ImpreciseType>::ValueIterationResult result = impreciseSolver.performValueIteration(dir, currentX, newX, b, storm::utility::convertNumber<ImpreciseType, ValueType>(precision), env.solver().minMax().getRelativeTerminationCriterion(), SolverGuarantee::LessOrEqual, overallIterations, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle());
typename IterativeMinMaxLinearEquationSolver<ImpreciseType>::ValueIterationResult result = impreciseSolver.performValueIteration(env, dir, currentX, newX, b, storm::utility::convertNumber<ImpreciseType, ValueType>(precision), env.solver().minMax().getRelativeTerminationCriterion(), SolverGuarantee::LessOrEqual, overallIterations, env.solver().minMax().getMaximalNumberOfIterations(), env.solver().minMax().getMultiplicationStyle());
// At this point, the result of the imprecise value iteration is stored in the (imprecise) current x.
@ -1370,45 +1369,16 @@ namespace storm {
template<typename ValueType>
void IterativeMinMaxLinearEquationSolver<ValueType>::clearCache() const {
multiplierA.reset();
auxiliaryRowGroupVector.reset();
auxiliaryRowGroupVector2.reset();
rowGroupOrdering.reset();
StandardMinMaxLinearEquationSolver<ValueType>::clearCache();
}
template<typename ValueType>
IterativeMinMaxLinearEquationSolverFactory<ValueType>::IterativeMinMaxLinearEquationSolverFactory() : StandardMinMaxLinearEquationSolverFactory<ValueType>() {
// Intentionally left empty
}
template<typename ValueType>
IterativeMinMaxLinearEquationSolverFactory<ValueType>::IterativeMinMaxLinearEquationSolverFactory(std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory) : StandardMinMaxLinearEquationSolverFactory<ValueType>(std::move(linearEquationSolverFactory)) {
// Intentionally left empty
}
template<typename ValueType>
IterativeMinMaxLinearEquationSolverFactory<ValueType>::IterativeMinMaxLinearEquationSolverFactory(EquationSolverType const& solverType) : StandardMinMaxLinearEquationSolverFactory<ValueType>(solverType) {
// Intentionally left empty
}
template<typename ValueType>
std::unique_ptr<MinMaxLinearEquationSolver<ValueType>> IterativeMinMaxLinearEquationSolverFactory<ValueType>::create(Environment const& env) const {
STORM_LOG_ASSERT(this->linearEquationSolverFactory, "Linear equation solver factory not initialized.");
auto method = env.solver().minMax().getMethod();
STORM_LOG_THROW(method == MinMaxMethod::ValueIteration || method == MinMaxMethod::PolicyIteration || method == MinMaxMethod::RationalSearch || method == MinMaxMethod::IntervalIteration || method == MinMaxMethod::SoundValueIteration, storm::exceptions::InvalidEnvironmentException, "This solver does not support the selected method.");
std::unique_ptr<MinMaxLinearEquationSolver<ValueType>> result = std::make_unique<IterativeMinMaxLinearEquationSolver<ValueType>>(this->linearEquationSolverFactory->clone());
result->setRequirementsChecked(this->isRequirementsCheckedSet());
return result;
}
template class IterativeMinMaxLinearEquationSolver<double>;
template class IterativeMinMaxLinearEquationSolverFactory<double>;
#ifdef STORM_HAVE_CARL
template class IterativeMinMaxLinearEquationSolver<storm::RationalNumber>;
template class IterativeMinMaxLinearEquationSolverFactory<storm::RationalNumber>;
#endif
}
}

21
src/storm/solver/IterativeMinMaxLinearEquationSolver.h

@ -5,6 +5,7 @@
#include "storm/utility/NumberTraits.h"
#include "storm/solver/LinearEquationSolver.h"
#include "storm/solver/Multiplier.h"
#include "storm/solver/StandardMinMaxLinearEquationSolver.h"
#include "storm/solver/SolverStatus.h"
@ -68,31 +69,21 @@ namespace storm {
template <typename ValueTypePrime>
friend class IterativeMinMaxLinearEquationSolver;
ValueIterationResult performValueIteration(OptimizationDirection dir, std::vector<ValueType>*& currentX, std::vector<ValueType>*& newX, std::vector<ValueType> const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const;
ValueIterationResult performValueIteration(Environment const& env, OptimizationDirection dir, std::vector<ValueType>*& currentX, std::vector<ValueType>*& newX, std::vector<ValueType> const& b, ValueType const& precision, bool relative, SolverGuarantee const& guarantee, uint64_t currentIterations, uint64_t maximalNumberOfIterations, storm::solver::MultiplicationStyle const& multiplicationStyle) const;
void createLinearEquationSolver(Environment const& env) const;
/// The factory used to obtain linear equation solvers.
std::unique_ptr<LinearEquationSolverFactory<ValueType>> linearEquationSolverFactory;
// possibly cached data
mutable std::unique_ptr<storm::solver::Multiplier<ValueType>> multiplierA;
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector; // A.rowGroupCount() entries
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryRowGroupVector2; // A.rowGroupCount() entries
mutable std::unique_ptr<std::vector<uint64_t>> rowGroupOrdering; // A.rowGroupCount() entries
SolverStatus updateStatusIfNotConverged(SolverStatus status, std::vector<ValueType> const& x, uint64_t iterations, uint64_t maximalNumberOfIterations, SolverGuarantee const& guarantee) const;
static void reportStatus(SolverStatus status, uint64_t iterations);
};
template<typename ValueType>
class IterativeMinMaxLinearEquationSolverFactory : public StandardMinMaxLinearEquationSolverFactory<ValueType> {
public:
IterativeMinMaxLinearEquationSolverFactory();
IterativeMinMaxLinearEquationSolverFactory(std::unique_ptr<LinearEquationSolverFactory<ValueType>>&& linearEquationSolverFactory);
IterativeMinMaxLinearEquationSolverFactory(EquationSolverType const& solverType);
// Make the other create methods visible.
using MinMaxLinearEquationSolverFactory<ValueType>::create;
virtual std::unique_ptr<MinMaxLinearEquationSolver<ValueType>> create(Environment const& env) const override;
};
}
}
Loading…
Cancel
Save