diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index cbb1f30b9..a98bb83f2 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -636,15 +636,18 @@ namespace storm { uint64_t iterations = 0; while (status == SolverStatus::InProgress && iterations < env.solver().minMax().getMaximalNumberOfIterations()) { + ++iterations; this->soundValueIterationHelper->performIterationStep(dir, b); if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(dir, relevantValuesPtr)) { status = SolverStatus::Converged; + } else { + // Update the status accordingly + if (this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition())) { + status = SolverStatus::TerminatedEarly; + } else if (iterations >= env.solver().minMax().getMaximalNumberOfIterations()) { + status = SolverStatus::MaximalIterationsExceeded; + } } - - // Update environment variables. - ++iterations; - // TODO: Implement custom termination criterion. We would need to add our errors to the stepBoundedX values (only if in second phase) - status = updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::None); // Potentially show progress. this->showProgressIterative(iterations); diff --git a/src/storm/solver/NativeLinearEquationSolver.cpp b/src/storm/solver/NativeLinearEquationSolver.cpp index 9cd6b14d2..04c3d3c76 100644 --- a/src/storm/solver/NativeLinearEquationSolver.cpp +++ b/src/storm/solver/NativeLinearEquationSolver.cpp @@ -603,8 +603,8 @@ namespace storm { converged = true; } - // todo: custom termination check - // terminate = .... + // Check whether we terminate early. + terminate = this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition()); // Update environment variables. ++iterations; diff --git a/src/storm/solver/TerminationCondition.cpp b/src/storm/solver/TerminationCondition.cpp index 745cc8a51..dcd88392b 100644 --- a/src/storm/solver/TerminationCondition.cpp +++ b/src/storm/solver/TerminationCondition.cpp @@ -4,12 +4,18 @@ #include "storm/adapters/RationalFunctionAdapter.h" #include "storm/utility/macros.h" +#include "storm/exceptions/InvalidArgumentException.h" namespace storm { namespace solver { template - bool NoTerminationCondition::terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee) const { + bool TerminationCondition::terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee) const { + return terminateNow([¤tValues] (uint64_t const& i) {return currentValues[i];}, guarantee); + } + + template + bool NoTerminationCondition::terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee) const { return false; } @@ -24,14 +30,17 @@ namespace storm { } template - bool TerminateIfFilteredSumExceedsThreshold::terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee) const { + bool TerminateIfFilteredSumExceedsThreshold::terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee) const { if (guarantee != SolverGuarantee::LessOrEqual) { return false; } - STORM_LOG_ASSERT(currentValues.size() == filter.size(), "Vectors sizes mismatch."); - ValueType currentThreshold = storm::utility::vector::sum_if(currentValues, filter); - return strict ? currentThreshold > this->threshold : currentThreshold >= this->threshold; + ValueType sum = storm::utility::zero(); + for (auto pos : filter) { + sum += valueGetter(pos); + // Exiting this loop early is not possible as values might be negative + } + return strict ? sum > this->threshold : sum >= this->threshold; } template @@ -42,17 +51,47 @@ namespace storm { template TerminateIfFilteredExtremumExceedsThreshold::TerminateIfFilteredExtremumExceedsThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum) : TerminateIfFilteredSumExceedsThreshold(filter, threshold, strict), useMinimum(useMinimum) { // Intentionally left empty. + STORM_LOG_THROW(!this->filter.empty(), storm::exceptions::InvalidArgumentException, "Empty Filter; Can not take extremum over empty set."); + cachedExtremumIndex = this->filter.getNextSetIndex(0); } template - bool TerminateIfFilteredExtremumExceedsThreshold::terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee) const { + bool TerminateIfFilteredExtremumExceedsThreshold::terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee) const { if (guarantee != SolverGuarantee::LessOrEqual) { return false; } - STORM_LOG_ASSERT(currentValues.size() == this->filter.size(), "Vectors sizes mismatch."); - ValueType currentValue = useMinimum ? storm::utility::vector::min_if(currentValues, this->filter) : storm::utility::vector::max_if(currentValues, this->filter); - return this->strict ? currentValue > this->threshold : currentValue >= this->threshold; + ValueType extremum = valueGetter(cachedExtremumIndex); + if (useMinimum && (this->strict ? extremum <= this->threshold : extremum < this->threshold)) { + // The extremum can only become smaller so we can return right now. + return false; + } + + if (useMinimum) { + if (this->strict) { + for (auto const& pos : this->filter) { + extremum = std::min(valueGetter(pos), extremum); + if (extremum <= this->threshold) { + cachedExtremumIndex = pos; + return false; + } + } + } else { + for (auto const& pos : this->filter) { + extremum = std::min(valueGetter(pos), extremum); + if (extremum < this->threshold) { + cachedExtremumIndex = pos; + return false; + } + } + } + } else { + for (auto const& pos : this->filter) { + extremum = std::max(valueGetter(pos), extremum); + } + } + + return this->strict ? extremum > this->threshold : extremum >= this->threshold; } template @@ -62,18 +101,47 @@ namespace storm { template TerminateIfFilteredExtremumBelowThreshold::TerminateIfFilteredExtremumBelowThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum) : TerminateIfFilteredSumExceedsThreshold(filter, threshold, strict), useMinimum(useMinimum) { - // Intentionally left empty. + STORM_LOG_THROW(!this->filter.empty(), storm::exceptions::InvalidArgumentException, "Empty Filter; Can not take extremum over empty set."); + cachedExtremumIndex = this->filter.getNextSetIndex(0); } template - bool TerminateIfFilteredExtremumBelowThreshold::terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee) const { + bool TerminateIfFilteredExtremumBelowThreshold::terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee) const { if (guarantee != SolverGuarantee::GreaterOrEqual) { return false; } - STORM_LOG_ASSERT(currentValues.size() == this->filter.size(), "Vectors sizes mismatch."); - ValueType currentValue = useMinimum ? storm::utility::vector::min_if(currentValues, this->filter) : storm::utility::vector::max_if(currentValues, this->filter); - return this->strict ? currentValue < this->threshold : currentValue <= this->threshold; + ValueType extremum = valueGetter(cachedExtremumIndex); + if (!useMinimum && (this->strict ? extremum >= this->threshold : extremum > this->threshold)) { + // The extremum can only become larger so we can return right now. + return false; + } + + if (useMinimum) { + for (auto const& pos : this->filter) { + extremum = std::min(valueGetter(pos), extremum); + } + } else { + if (this->strict) { + for (auto const& pos : this->filter) { + extremum = std::max(valueGetter(pos), extremum); + if (extremum >= this->threshold) { + cachedExtremumIndex = pos; + return false; + } + } + } else { + for (auto const& pos : this->filter) { + extremum = std::max(valueGetter(pos), extremum); + if (extremum > this->threshold) { + cachedExtremumIndex = pos; + return false; + } + } + } + } + + return this->strict ? extremum < this->threshold : extremum <= this->threshold; } template diff --git a/src/storm/solver/TerminationCondition.h b/src/storm/solver/TerminationCondition.h index a8e21697e..98402d1af 100644 --- a/src/storm/solver/TerminationCondition.h +++ b/src/storm/solver/TerminationCondition.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include "storm/solver/SolverGuarantee.h" #include "storm/storage/BitVector.h" @@ -15,7 +15,8 @@ namespace storm { /*! * Retrieves whether the guarantee provided by the solver for the current result is sufficient to terminate. */ - virtual bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const = 0; + virtual bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const; + virtual bool terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const = 0; /*! * Retrieves whether the termination criterion requires the given guarantee in order to decide termination. @@ -27,7 +28,7 @@ namespace storm { template class NoTerminationCondition : public TerminationCondition { public: - virtual bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + virtual bool terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; }; @@ -36,7 +37,7 @@ namespace storm { public: TerminateIfFilteredSumExceedsThreshold(storm::storage::BitVector const& filter, ValueType const& threshold, bool strict); - bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + bool terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; protected: @@ -50,11 +51,12 @@ namespace storm { public: TerminateIfFilteredExtremumExceedsThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum); - bool terminateNow(std::vector const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + bool terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; protected: bool useMinimum; + mutable uint64_t cachedExtremumIndex; }; template @@ -62,11 +64,12 @@ namespace storm { public: TerminateIfFilteredExtremumBelowThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum); - bool terminateNow(std::vector const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + bool terminateNow(std::function const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; protected: bool useMinimum; + mutable uint64_t cachedExtremumIndex; }; } } diff --git a/src/storm/solver/helper/SoundValueIterationHelper.cpp b/src/storm/solver/helper/SoundValueIterationHelper.cpp index 4a831e7c7..68bd1f03a 100644 --- a/src/storm/solver/helper/SoundValueIterationHelper.cpp +++ b/src/storm/solver/helper/SoundValueIterationHelper.cpp @@ -294,9 +294,28 @@ namespace storm { << ". Decision value is " << (hasDecisionValue ? decisionValue : storm::utility::zero()) << (hasDecisionValue ? "" : "(none)") << "."); - } + template + bool SoundValueIterationHelper::checkCustomTerminationCondition(storm::solver::TerminationCondition const& condition) { + if (condition.requiresGuarantee(storm::solver::SolverGuarantee::GreaterOrEqual)) { + if (hasUpperBound && condition.terminateNow( + [&](uint64_t const& i) { + return x[i] + y[i] * upperBound; + }, storm::solver::SolverGuarantee::GreaterOrEqual)) { + return true; + } + } else if (condition.requiresGuarantee(storm::solver::SolverGuarantee::LessOrEqual)) { + if (hasLowerBound && condition.terminateNow( + [&](uint64_t const& i) { + return x[i] + y[i] * lowerBound; + }, storm::solver::SolverGuarantee::LessOrEqual)) { + return true; + } + } + return false; + } + template bool SoundValueIterationHelper::checkConvergencePhase1() { // Return true if y ('the probability to stay within the matrix') is < 1 at every entry diff --git a/src/storm/solver/helper/SoundValueIterationHelper.h b/src/storm/solver/helper/SoundValueIterationHelper.h index d9f7bc476..54cdc1fb1 100644 --- a/src/storm/solver/helper/SoundValueIterationHelper.h +++ b/src/storm/solver/helper/SoundValueIterationHelper.h @@ -3,6 +3,7 @@ #include #include "storm/solver/OptimizationDirection.h" +#include "storm/solver/TerminationCondition.h" namespace storm { @@ -61,6 +62,11 @@ namespace storm { */ bool checkConvergenceUpdateBounds(storm::storage::BitVector const* relevantValues = nullptr); + /*! + * Checks whether the provided termination condition triggers termination + */ + bool checkCustomTerminationCondition(storm::solver::TerminationCondition const& condition); + private: enum class InternalOptimizationDirection { @@ -92,7 +98,6 @@ namespace storm { template void checkIfDecisionValueBlocks(); - // Auxiliary helper functions to avoid case distinctions due to different optimization directions template inline bool better(ValueType const& val1, ValueType const& val2) { @@ -119,7 +124,6 @@ namespace storm { return (dir == InternalOptimizationDirection::Maximize) ? minIndex : maxIndex; } - std::vector& x; std::vector& y; std::vector xTmp, yTmp; diff --git a/src/test/storm-pars/modelchecker/SparseDtmcParameterLiftingTest.cpp b/src/test/storm-pars/modelchecker/SparseDtmcParameterLiftingTest.cpp index b00b44c0d..dd6d59cf8 100644 --- a/src/test/storm-pars/modelchecker/SparseDtmcParameterLiftingTest.cpp +++ b/src/test/storm-pars/modelchecker/SparseDtmcParameterLiftingTest.cpp @@ -21,6 +21,18 @@ namespace { return env; } }; + + class DoubleSVIEnvironment { + public: + typedef double ValueType; + static storm::Environment createEnvironment() { + storm::Environment env; + env.solver().minMax().setMethod(storm::solver::MinMaxMethod::SoundValueIteration); + env.solver().minMax().setPrecision(storm::utility::convertNumber(1e-6)); + return env; + } + }; + class RationalPiEnvironment { public: typedef storm::RationalNumber ValueType; @@ -44,6 +56,7 @@ namespace { typedef ::testing::Types< DoubleViEnvironment, + DoubleSVIEnvironment, RationalPiEnvironment > TestingTypes; diff --git a/src/test/storm/modelchecker/MdpPrctlModelCheckerTest.cpp b/src/test/storm/modelchecker/MdpPrctlModelCheckerTest.cpp index 0538fa081..e0741fa30 100644 --- a/src/test/storm/modelchecker/MdpPrctlModelCheckerTest.cpp +++ b/src/test/storm/modelchecker/MdpPrctlModelCheckerTest.cpp @@ -438,6 +438,8 @@ namespace { TYPED_TEST(MdpPrctlModelCheckerTest, consensus) { std::string formulasString = "Pmax=? [F \"finished\"]"; formulasString += "; Pmax=? [F \"all_coins_equal_1\"]"; + formulasString += "; P<0.8 [F \"all_coins_equal_1\"]"; + formulasString += "; P<0.9 [F \"all_coins_equal_1\"]"; formulasString += "; Rmax=? [F \"all_coins_equal_1\"]"; formulasString += "; Rmin=? [F \"all_coins_equal_1\"]"; formulasString += "; Rmax=? [F \"finished\"]"; @@ -459,15 +461,21 @@ namespace { EXPECT_NEAR(this->parseNumber("57/64"), this->getQuantitativeResultAtInitialState(model, result), this->precision()); result = checker->check(this->env(), tasks[2]); - EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result))); + EXPECT_FALSE(this->getQualitativeResultAtInitialState(model, result)); result = checker->check(this->env(), tasks[3]); - EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result))); + EXPECT_TRUE(this->getQualitativeResultAtInitialState(model, result)); result = checker->check(this->env(), tasks[4]); + EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result))); + + result = checker->check(this->env(), tasks[5]); + EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result))); + + result = checker->check(this->env(), tasks[6]); EXPECT_NEAR(this->parseNumber("75"), this->getQuantitativeResultAtInitialState(model, result), this->precision()); - result = checker->check(this->env(), tasks[5]); + result = checker->check(this->env(), tasks[7]); EXPECT_NEAR(this->parseNumber("48"), this->getQuantitativeResultAtInitialState(model, result), this->precision()); }