Browse Source

Custom Termination Conditions for sound value iteration

tempestpy_adaptions
TimQu 7 years ago
parent
commit
12f8685080
  1. 13
      src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp
  2. 4
      src/storm/solver/NativeLinearEquationSolver.cpp
  3. 96
      src/storm/solver/TerminationCondition.cpp
  4. 15
      src/storm/solver/TerminationCondition.h
  5. 21
      src/storm/solver/helper/SoundValueIterationHelper.cpp
  6. 8
      src/storm/solver/helper/SoundValueIterationHelper.h
  7. 13
      src/test/storm-pars/modelchecker/SparseDtmcParameterLiftingTest.cpp
  8. 14
      src/test/storm/modelchecker/MdpPrctlModelCheckerTest.cpp

13
src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp

@ -636,15 +636,18 @@ namespace storm {
uint64_t iterations = 0;
while (status == SolverStatus::InProgress && iterations < env.solver().minMax().getMaximalNumberOfIterations()) {
++iterations;
this->soundValueIterationHelper->performIterationStep(dir, b);
if (this->soundValueIterationHelper->checkConvergenceUpdateBounds(dir, relevantValuesPtr)) {
status = SolverStatus::Converged;
} else {
// Update the status accordingly
if (this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition())) {
status = SolverStatus::TerminatedEarly;
} else if (iterations >= env.solver().minMax().getMaximalNumberOfIterations()) {
status = SolverStatus::MaximalIterationsExceeded;
}
}
// Update environment variables.
++iterations;
// TODO: Implement custom termination criterion. We would need to add our errors to the stepBoundedX values (only if in second phase)
status = updateStatusIfNotConverged(status, x, iterations, env.solver().minMax().getMaximalNumberOfIterations(), SolverGuarantee::None);
// Potentially show progress.
this->showProgressIterative(iterations);

4
src/storm/solver/NativeLinearEquationSolver.cpp

@ -603,8 +603,8 @@ namespace storm {
converged = true;
}
// todo: custom termination check
// terminate = ....
// Check whether we terminate early.
terminate = this->hasCustomTerminationCondition() && this->soundValueIterationHelper->checkCustomTerminationCondition(this->getTerminationCondition());
// Update environment variables.
++iterations;

96
src/storm/solver/TerminationCondition.cpp

@ -4,12 +4,18 @@
#include "storm/adapters/RationalFunctionAdapter.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/InvalidArgumentException.h"
namespace storm {
namespace solver {
template<typename ValueType>
bool NoTerminationCondition<ValueType>::terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee) const {
bool TerminationCondition<ValueType>::terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee) const {
return terminateNow([&currentValues] (uint64_t const& i) {return currentValues[i];}, guarantee);
}
template<typename ValueType>
bool NoTerminationCondition<ValueType>::terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee) const {
return false;
}
@ -24,14 +30,17 @@ namespace storm {
}
template<typename ValueType>
bool TerminateIfFilteredSumExceedsThreshold<ValueType>::terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee) const {
bool TerminateIfFilteredSumExceedsThreshold<ValueType>::terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee) const {
if (guarantee != SolverGuarantee::LessOrEqual) {
return false;
}
STORM_LOG_ASSERT(currentValues.size() == filter.size(), "Vectors sizes mismatch.");
ValueType currentThreshold = storm::utility::vector::sum_if(currentValues, filter);
return strict ? currentThreshold > this->threshold : currentThreshold >= this->threshold;
ValueType sum = storm::utility::zero<ValueType>();
for (auto pos : filter) {
sum += valueGetter(pos);
// Exiting this loop early is not possible as values might be negative
}
return strict ? sum > this->threshold : sum >= this->threshold;
}
template<typename ValueType>
@ -42,17 +51,47 @@ namespace storm {
template<typename ValueType>
TerminateIfFilteredExtremumExceedsThreshold<ValueType>::TerminateIfFilteredExtremumExceedsThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum) : TerminateIfFilteredSumExceedsThreshold<ValueType>(filter, threshold, strict), useMinimum(useMinimum) {
// Intentionally left empty.
STORM_LOG_THROW(!this->filter.empty(), storm::exceptions::InvalidArgumentException, "Empty Filter; Can not take extremum over empty set.");
cachedExtremumIndex = this->filter.getNextSetIndex(0);
}
template<typename ValueType>
bool TerminateIfFilteredExtremumExceedsThreshold<ValueType>::terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee) const {
bool TerminateIfFilteredExtremumExceedsThreshold<ValueType>::terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee) const {
if (guarantee != SolverGuarantee::LessOrEqual) {
return false;
}
STORM_LOG_ASSERT(currentValues.size() == this->filter.size(), "Vectors sizes mismatch.");
ValueType currentValue = useMinimum ? storm::utility::vector::min_if(currentValues, this->filter) : storm::utility::vector::max_if(currentValues, this->filter);
return this->strict ? currentValue > this->threshold : currentValue >= this->threshold;
ValueType extremum = valueGetter(cachedExtremumIndex);
if (useMinimum && (this->strict ? extremum <= this->threshold : extremum < this->threshold)) {
// The extremum can only become smaller so we can return right now.
return false;
}
if (useMinimum) {
if (this->strict) {
for (auto const& pos : this->filter) {
extremum = std::min(valueGetter(pos), extremum);
if (extremum <= this->threshold) {
cachedExtremumIndex = pos;
return false;
}
}
} else {
for (auto const& pos : this->filter) {
extremum = std::min(valueGetter(pos), extremum);
if (extremum < this->threshold) {
cachedExtremumIndex = pos;
return false;
}
}
}
} else {
for (auto const& pos : this->filter) {
extremum = std::max(valueGetter(pos), extremum);
}
}
return this->strict ? extremum > this->threshold : extremum >= this->threshold;
}
template<typename ValueType>
@ -62,18 +101,47 @@ namespace storm {
template<typename ValueType>
TerminateIfFilteredExtremumBelowThreshold<ValueType>::TerminateIfFilteredExtremumBelowThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum) : TerminateIfFilteredSumExceedsThreshold<ValueType>(filter, threshold, strict), useMinimum(useMinimum) {
// Intentionally left empty.
STORM_LOG_THROW(!this->filter.empty(), storm::exceptions::InvalidArgumentException, "Empty Filter; Can not take extremum over empty set.");
cachedExtremumIndex = this->filter.getNextSetIndex(0);
}
template<typename ValueType>
bool TerminateIfFilteredExtremumBelowThreshold<ValueType>::terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee) const {
bool TerminateIfFilteredExtremumBelowThreshold<ValueType>::terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee) const {
if (guarantee != SolverGuarantee::GreaterOrEqual) {
return false;
}
STORM_LOG_ASSERT(currentValues.size() == this->filter.size(), "Vectors sizes mismatch.");
ValueType currentValue = useMinimum ? storm::utility::vector::min_if(currentValues, this->filter) : storm::utility::vector::max_if(currentValues, this->filter);
return this->strict ? currentValue < this->threshold : currentValue <= this->threshold;
ValueType extremum = valueGetter(cachedExtremumIndex);
if (!useMinimum && (this->strict ? extremum >= this->threshold : extremum > this->threshold)) {
// The extremum can only become larger so we can return right now.
return false;
}
if (useMinimum) {
for (auto const& pos : this->filter) {
extremum = std::min(valueGetter(pos), extremum);
}
} else {
if (this->strict) {
for (auto const& pos : this->filter) {
extremum = std::max(valueGetter(pos), extremum);
if (extremum >= this->threshold) {
cachedExtremumIndex = pos;
return false;
}
}
} else {
for (auto const& pos : this->filter) {
extremum = std::max(valueGetter(pos), extremum);
if (extremum > this->threshold) {
cachedExtremumIndex = pos;
return false;
}
}
}
}
return this->strict ? extremum < this->threshold : extremum <= this->threshold;
}
template<typename ValueType>

15
src/storm/solver/TerminationCondition.h

@ -1,6 +1,6 @@
#pragma once
#include <vector>
#include <functional>
#include "storm/solver/SolverGuarantee.h"
#include "storm/storage/BitVector.h"
@ -15,7 +15,8 @@ namespace storm {
/*!
* Retrieves whether the guarantee provided by the solver for the current result is sufficient to terminate.
*/
virtual bool terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const = 0;
virtual bool terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const;
virtual bool terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const = 0;
/*!
* Retrieves whether the termination criterion requires the given guarantee in order to decide termination.
@ -27,7 +28,7 @@ namespace storm {
template<typename ValueType>
class NoTerminationCondition : public TerminationCondition<ValueType> {
public:
virtual bool terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
virtual bool terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override;
};
@ -36,7 +37,7 @@ namespace storm {
public:
TerminateIfFilteredSumExceedsThreshold(storm::storage::BitVector const& filter, ValueType const& threshold, bool strict);
bool terminateNow(std::vector<ValueType> const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
bool terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override;
protected:
@ -50,11 +51,12 @@ namespace storm {
public:
TerminateIfFilteredExtremumExceedsThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum);
bool terminateNow(std::vector<ValueType> const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
bool terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override;
protected:
bool useMinimum;
mutable uint64_t cachedExtremumIndex;
};
template<typename ValueType>
@ -62,11 +64,12 @@ namespace storm {
public:
TerminateIfFilteredExtremumBelowThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum);
bool terminateNow(std::vector<ValueType> const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
bool terminateNow(std::function<ValueType(uint64_t const&)> const& valueGetter, SolverGuarantee const& guarantee = SolverGuarantee::None) const override;
virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override;
protected:
bool useMinimum;
mutable uint64_t cachedExtremumIndex;
};
}
}

21
src/storm/solver/helper/SoundValueIterationHelper.cpp

@ -294,9 +294,28 @@ namespace storm {
<< ". Decision value is "
<< (hasDecisionValue ? decisionValue : storm::utility::zero<ValueType>()) << (hasDecisionValue ? "" : "(none)")
<< ".");
}
template<typename ValueType>
bool SoundValueIterationHelper<ValueType>::checkCustomTerminationCondition(storm::solver::TerminationCondition<ValueType> const& condition) {
if (condition.requiresGuarantee(storm::solver::SolverGuarantee::GreaterOrEqual)) {
if (hasUpperBound && condition.terminateNow(
[&](uint64_t const& i) {
return x[i] + y[i] * upperBound;
}, storm::solver::SolverGuarantee::GreaterOrEqual)) {
return true;
}
} else if (condition.requiresGuarantee(storm::solver::SolverGuarantee::LessOrEqual)) {
if (hasLowerBound && condition.terminateNow(
[&](uint64_t const& i) {
return x[i] + y[i] * lowerBound;
}, storm::solver::SolverGuarantee::LessOrEqual)) {
return true;
}
}
return false;
}
template<typename ValueType>
bool SoundValueIterationHelper<ValueType>::checkConvergencePhase1() {
// Return true if y ('the probability to stay within the matrix') is < 1 at every entry

8
src/storm/solver/helper/SoundValueIterationHelper.h

@ -3,6 +3,7 @@
#include <vector>
#include "storm/solver/OptimizationDirection.h"
#include "storm/solver/TerminationCondition.h"
namespace storm {
@ -61,6 +62,11 @@ namespace storm {
*/
bool checkConvergenceUpdateBounds(storm::storage::BitVector const* relevantValues = nullptr);
/*!
* Checks whether the provided termination condition triggers termination
*/
bool checkCustomTerminationCondition(storm::solver::TerminationCondition<ValueType> const& condition);
private:
enum class InternalOptimizationDirection {
@ -92,7 +98,6 @@ namespace storm {
template<InternalOptimizationDirection dir>
void checkIfDecisionValueBlocks();
// Auxiliary helper functions to avoid case distinctions due to different optimization directions
template<InternalOptimizationDirection dir>
inline bool better(ValueType const& val1, ValueType const& val2) {
@ -119,7 +124,6 @@ namespace storm {
return (dir == InternalOptimizationDirection::Maximize) ? minIndex : maxIndex;
}
std::vector<ValueType>& x;
std::vector<ValueType>& y;
std::vector<ValueType> xTmp, yTmp;

13
src/test/storm-pars/modelchecker/SparseDtmcParameterLiftingTest.cpp

@ -21,6 +21,18 @@ namespace {
return env;
}
};
class DoubleSVIEnvironment {
public:
typedef double ValueType;
static storm::Environment createEnvironment() {
storm::Environment env;
env.solver().minMax().setMethod(storm::solver::MinMaxMethod::SoundValueIteration);
env.solver().minMax().setPrecision(storm::utility::convertNumber<storm::RationalNumber>(1e-6));
return env;
}
};
class RationalPiEnvironment {
public:
typedef storm::RationalNumber ValueType;
@ -44,6 +56,7 @@ namespace {
typedef ::testing::Types<
DoubleViEnvironment,
DoubleSVIEnvironment,
RationalPiEnvironment
> TestingTypes;

14
src/test/storm/modelchecker/MdpPrctlModelCheckerTest.cpp

@ -438,6 +438,8 @@ namespace {
TYPED_TEST(MdpPrctlModelCheckerTest, consensus) {
std::string formulasString = "Pmax=? [F \"finished\"]";
formulasString += "; Pmax=? [F \"all_coins_equal_1\"]";
formulasString += "; P<0.8 [F \"all_coins_equal_1\"]";
formulasString += "; P<0.9 [F \"all_coins_equal_1\"]";
formulasString += "; Rmax=? [F \"all_coins_equal_1\"]";
formulasString += "; Rmin=? [F \"all_coins_equal_1\"]";
formulasString += "; Rmax=? [F \"finished\"]";
@ -459,15 +461,21 @@ namespace {
EXPECT_NEAR(this->parseNumber("57/64"), this->getQuantitativeResultAtInitialState(model, result), this->precision());
result = checker->check(this->env(), tasks[2]);
EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result)));
EXPECT_FALSE(this->getQualitativeResultAtInitialState(model, result));
result = checker->check(this->env(), tasks[3]);
EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result)));
EXPECT_TRUE(this->getQualitativeResultAtInitialState(model, result));
result = checker->check(this->env(), tasks[4]);
EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result)));
result = checker->check(this->env(), tasks[5]);
EXPECT_TRUE(storm::utility::isInfinity(this->getQuantitativeResultAtInitialState(model, result)));
result = checker->check(this->env(), tasks[6]);
EXPECT_NEAR(this->parseNumber("75"), this->getQuantitativeResultAtInitialState(model, result), this->precision());
result = checker->check(this->env(), tasks[5]);
result = checker->check(this->env(), tasks[7]);
EXPECT_NEAR(this->parseNumber("48"), this->getQuantitativeResultAtInitialState(model, result), this->precision());
}

Loading…
Cancel
Save