From 383e2172d4efc9edddab9b459a651ba5641fe3c6 Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Wed, 4 Mar 2020 20:59:53 +0100 Subject: [PATCH] Added OVI for linear equation systems (i.e. DTMC/CTMC) --- .../modules/NativeEquationSolverSettings.cpp | 4 +- .../solver/NativeLinearEquationSolver.cpp | 72 ++++++++++++++++++- src/storm/solver/NativeLinearEquationSolver.h | 1 + src/storm/solver/SolverSelectionOptions.cpp | 2 + src/storm/solver/SolverSelectionOptions.h | 2 +- 5 files changed, 76 insertions(+), 5 deletions(-) diff --git a/src/storm/settings/modules/NativeEquationSolverSettings.cpp b/src/storm/settings/modules/NativeEquationSolverSettings.cpp index 0180b3b24..4dd83f075 100644 --- a/src/storm/settings/modules/NativeEquationSolverSettings.cpp +++ b/src/storm/settings/modules/NativeEquationSolverSettings.cpp @@ -26,7 +26,7 @@ namespace storm { const std::string NativeEquationSolverSettings::intervalIterationSymmetricUpdatesOptionName = "symmetricupdates"; NativeEquationSolverSettings::NativeEquationSolverSettings() : ModuleSettings(moduleName) { - std::vector methods = { "jacobi", "gaussseidel", "sor", "walkerchae", "power", "sound-value-iteration", "svi", "interval-iteration", "ii", "ratsearch" }; + std::vector methods = { "jacobi", "gaussseidel", "sor", "walkerchae", "power", "sound-value-iteration", "svi", "optimistic-value-itearation", "ovi", "interval-iteration", "ii", "ratsearch" }; this->addOption(storm::settings::OptionBuilder(moduleName, techniqueOptionName, true, "The method to be used for solving linear equation systems with the native engine.").setIsAdvanced().addArgument(storm::settings::ArgumentBuilder::createStringArgument("name", "The name of the method to use.").addValidatorString(ArgumentValidatorFactory::createMultipleChoiceValidator(methods)).setDefaultValueString("jacobi").build()).build()); this->addOption(storm::settings::OptionBuilder(moduleName, maximalIterationsOptionName, false, "The maximal number of iterations to perform before iterative solving is aborted.").setIsAdvanced().setShortName(maximalIterationsOptionShortName).addArgument(storm::settings::ArgumentBuilder::createUnsignedIntegerArgument("count", "The maximal iteration count.").build()).build()); @@ -66,6 +66,8 @@ namespace storm { return storm::solver::NativeLinearEquationSolverMethod::Power; } else if (linearEquationSystemTechniqueAsString == "sound-value-iteration" || linearEquationSystemTechniqueAsString == "svi") { return storm::solver::NativeLinearEquationSolverMethod::SoundValueIteration; + } else if (linearEquationSystemTechniqueAsString == "optimistic-value-iteration" || linearEquationSystemTechniqueAsString == "ovi") { + return storm::solver::NativeLinearEquationSolverMethod::OptimisticValueIteration; } else if (linearEquationSystemTechniqueAsString == "interval-iteration" || linearEquationSystemTechniqueAsString == "ii") { return storm::solver::NativeLinearEquationSolverMethod::IntervalIteration; } else if (linearEquationSystemTechniqueAsString == "ratsearch") { diff --git a/src/storm/solver/NativeLinearEquationSolver.cpp b/src/storm/solver/NativeLinearEquationSolver.cpp index a79bc6b3e..da2ba626d 100644 --- a/src/storm/solver/NativeLinearEquationSolver.cpp +++ b/src/storm/solver/NativeLinearEquationSolver.cpp @@ -10,6 +10,7 @@ #include "storm/utility/constants.h" #include "storm/utility/vector.h" #include "storm/solver/helper/SoundValueIterationHelper.h" +#include "storm/solver/helper/OptimisticValueIterationHelper.h" #include "storm/solver/Multiplier.h" #include "storm/exceptions/InvalidStateException.h" #include "storm/exceptions/InvalidEnvironmentException.h" @@ -623,6 +624,69 @@ namespace storm { return converged; } + template + bool NativeLinearEquationSolver::solveEquationsOptimisticValueIteration(Environment const& env, std::vector& x, std::vector const& b) const { + + if (!this->multiplier) { + this->multiplier = storm::solver::MultiplierFactory().create(env, *this->A); + } + + if (!this->cachedRowVector) { + this->cachedRowVector = std::make_unique>(this->A->getRowCount()); + } + if (!this->cachedRowVector2) { + this->cachedRowVector2 = std::make_unique>(this->A->getRowCount()); + } + + // By default, we can not provide any guarantee + SolverGuarantee guarantee = SolverGuarantee::None; + // Get handle to multiplier. + storm::solver::Multiplier const &multiplier = *this->multiplier; + // Allow aliased multiplications. + storm::solver::MultiplicationStyle multiplicationStyle = env.solver().native().getPowerMethodMultiplicationStyle(); + bool useGaussSeidelMultiplication = multiplicationStyle == storm::solver::MultiplicationStyle::GaussSeidel; + + boost::optional relevantValues; + if (this->hasRelevantValues()) { + relevantValues = this->getRelevantValues(); + } + + // x has to start with a lower bound. + this->createLowerBoundsVector(x); + + std::vector* lowerX = &x; + std::vector* upperX = this->cachedRowVector.get(); + std::vector* auxVector = this->cachedRowVector2.get(); + + this->startMeasureProgress(); + + auto statusIters = storm::solver::helper::solveEquationsOptimisticValueIteration(env, lowerX, upperX, auxVector, + [&] (std::vector*& y, std::vector*& yPrime, ValueType const& precision, bool const& relative, uint64_t const& i, uint64_t const& maxI) { + this->showProgressIterative(i); + return performPowerIteration(env, y, yPrime, b, precision, relative, guarantee, i, maxI, multiplicationStyle); + }, + [&] (std::vector* y, std::vector* yPrime, uint64_t const& i) { + this->showProgressIterative(i); + if (useGaussSeidelMultiplication) { + // Copy over the current vectors so we can modify them in-place. + // This is necessary as we want to compare the new values with the current ones. + *yPrime = *y; + multiplier.multiplyGaussSeidel(env, *y, &b); + } else { + multiplier.multiply(env, *y, &b, *yPrime); + std::swap(y, yPrime); + } + }, relevantValues); + auto two = storm::utility::convertNumber(2.0); + storm::utility::vector::applyPointwise(*lowerX, *upperX, x, [&two] (ValueType const& a, ValueType const& b) -> ValueType { return (a + b) / two; }); + this->logIterations(statusIters.first == SolverStatus::Converged, statusIters.first == SolverStatus::TerminatedEarly, statusIters.second); + + if (!this->isCachingEnabled()) { + clearCache(); + } + return statusIters.first == SolverStatus::Converged || statusIters.first == SolverStatus::TerminatedEarly; + } + template bool NativeLinearEquationSolver::solveEquationsRationalSearch(Environment const& env, std::vector& x, std::vector const& b) const { return solveEquationsRationalSearchHelper(env, x, b); @@ -908,7 +972,7 @@ namespace storm { } else { STORM_LOG_WARN("The selected solution method does not guarantee exact results."); } - } else if (env.solver().isForceSoundness() && method != NativeLinearEquationSolverMethod::SoundValueIteration && method != NativeLinearEquationSolverMethod::IntervalIteration && method != NativeLinearEquationSolverMethod::RationalSearch) { + } else if (env.solver().isForceSoundness() && method != NativeLinearEquationSolverMethod::SoundValueIteration && method != NativeLinearEquationSolverMethod::OptimisticValueIteration && method != NativeLinearEquationSolverMethod::IntervalIteration && method != NativeLinearEquationSolverMethod::RationalSearch) { if (env.solver().native().isMethodSetFromDefault()) { method = NativeLinearEquationSolverMethod::SoundValueIteration; STORM_LOG_INFO("Selecting '" + toString(method) + "' as the solution technique to guarantee sound results. If you want to override this, please explicitly specify a different method."); @@ -935,6 +999,8 @@ namespace storm { return this->solveEquationsPower(env, x, b); case NativeLinearEquationSolverMethod::SoundValueIteration: return this->solveEquationsSoundValueIteration(env, x, b); + case NativeLinearEquationSolverMethod::OptimisticValueIteration: + return this->solveEquationsOptimisticValueIteration(env, x, b); case NativeLinearEquationSolverMethod::IntervalIteration: return this->solveEquationsIntervalIteration(env, x, b); case NativeLinearEquationSolverMethod::RationalSearch: @@ -947,7 +1013,7 @@ namespace storm { template LinearEquationSolverProblemFormat NativeLinearEquationSolver::getEquationProblemFormat(Environment const& env) const { auto method = getMethod(env, storm::NumberTraits::IsExact || env.solver().isForceExact()); - if (method == NativeLinearEquationSolverMethod::Power || method == NativeLinearEquationSolverMethod::SoundValueIteration || method == NativeLinearEquationSolverMethod::RationalSearch || method == NativeLinearEquationSolverMethod::IntervalIteration) { + if (method == NativeLinearEquationSolverMethod::Power || method == NativeLinearEquationSolverMethod::SoundValueIteration || method == NativeLinearEquationSolverMethod::OptimisticValueIteration || method == NativeLinearEquationSolverMethod::RationalSearch || method == NativeLinearEquationSolverMethod::IntervalIteration) { return LinearEquationSolverProblemFormat::FixedPointSystem; } else { return LinearEquationSolverProblemFormat::EquationSystem; @@ -960,7 +1026,7 @@ namespace storm { auto method = getMethod(env, storm::NumberTraits::IsExact || env.solver().isForceExact()); if (method == NativeLinearEquationSolverMethod::IntervalIteration) { requirements.requireBounds(); - } else if (method == NativeLinearEquationSolverMethod::RationalSearch) { + } else if (method == NativeLinearEquationSolverMethod::RationalSearch || method == NativeLinearEquationSolverMethod::OptimisticValueIteration) { requirements.requireLowerBounds(); } else if (method == NativeLinearEquationSolverMethod::SoundValueIteration) { requirements.requireBounds(false); diff --git a/src/storm/solver/NativeLinearEquationSolver.h b/src/storm/solver/NativeLinearEquationSolver.h index 0cc493535..d551cd931 100644 --- a/src/storm/solver/NativeLinearEquationSolver.h +++ b/src/storm/solver/NativeLinearEquationSolver.h @@ -66,6 +66,7 @@ namespace storm { virtual bool solveEquationsWalkerChae(storm::Environment const& env, std::vector& x, std::vector const& b) const; virtual bool solveEquationsPower(storm::Environment const& env, std::vector& x, std::vector const& b) const; virtual bool solveEquationsSoundValueIteration(storm::Environment const& env, std::vector& x, std::vector const& b) const; + virtual bool solveEquationsOptimisticValueIteration(storm::Environment const& env, std::vector& x, std::vector const& b) const; virtual bool solveEquationsIntervalIteration(storm::Environment const& env, std::vector& x, std::vector const& b) const; virtual bool solveEquationsRationalSearch(storm::Environment const& env, std::vector& x, std::vector const& b) const; diff --git a/src/storm/solver/SolverSelectionOptions.cpp b/src/storm/solver/SolverSelectionOptions.cpp index 9e69c09a4..12e1cac02 100644 --- a/src/storm/solver/SolverSelectionOptions.cpp +++ b/src/storm/solver/SolverSelectionOptions.cpp @@ -124,6 +124,8 @@ namespace storm { return "Power"; case NativeLinearEquationSolverMethod::SoundValueIteration: return "SoundValueIteration"; + case NativeLinearEquationSolverMethod::OptimisticValueIteration: + return "optimisticvalueiteration"; case NativeLinearEquationSolverMethod::IntervalIteration: return "IntervalIteration"; case NativeLinearEquationSolverMethod::RationalSearch: diff --git a/src/storm/solver/SolverSelectionOptions.h b/src/storm/solver/SolverSelectionOptions.h index 6bb1300fb..7c2c90693 100644 --- a/src/storm/solver/SolverSelectionOptions.h +++ b/src/storm/solver/SolverSelectionOptions.h @@ -16,7 +16,7 @@ namespace storm { ExtendEnumsWithSelectionField(EquationSolverType, Native, Gmmxx, Eigen, Elimination, Topological) ExtendEnumsWithSelectionField(SmtSolverType, Z3, Mathsat) - ExtendEnumsWithSelectionField(NativeLinearEquationSolverMethod, Jacobi, GaussSeidel, SOR, WalkerChae, Power, SoundValueIteration, IntervalIteration, RationalSearch) + ExtendEnumsWithSelectionField(NativeLinearEquationSolverMethod, Jacobi, GaussSeidel, SOR, WalkerChae, Power, SoundValueIteration, OptimisticValueIteration, IntervalIteration, RationalSearch) ExtendEnumsWithSelectionField(GmmxxLinearEquationSolverMethod, Bicgstab, Qmr, Gmres) ExtendEnumsWithSelectionField(GmmxxLinearEquationSolverPreconditioner, Ilu, Diagonal, None) ExtendEnumsWithSelectionField(EigenLinearEquationSolverMethod, SparseLU, Bicgstab, DGmres, Gmres)