diff --git a/resources/3rdparty/sylvan/src/sylvan_mtbdd_storm.c b/resources/3rdparty/sylvan/src/sylvan_mtbdd_storm.c index caf2ad18d..55c78e4f5 100644 --- a/resources/3rdparty/sylvan/src/sylvan_mtbdd_storm.c +++ b/resources/3rdparty/sylvan/src/sylvan_mtbdd_storm.c @@ -606,7 +606,7 @@ TASK_IMPL_2(MTBDD, mtbdd_op_sharpen, MTBDD, a, size_t, p) MTBDD result = mtbdd_storm_rational_number(storm_double_sharpen(mtbdd_getdouble(a), p)); return result; } else if (mtbddnode_gettype(na) == srn_type) { - return mtbdd_storm_rational_number(storm_rational_number_sharpen((storm_rational_number_ptr)mtbdd_getstorm_rational_function_ptr(a), p)); + return mtbdd_storm_rational_number(storm_rational_number_sharpen((storm_rational_number_ptr)mtbdd_getstorm_rational_number_ptr(a), p)); } else { printf("ERROR: Unsupported value type in sharpen.\n"); assert(0); diff --git a/src/storm/solver/SymbolicEquationSolver.cpp b/src/storm/solver/SymbolicEquationSolver.cpp index a76f94873..cf5efc71e 100644 --- a/src/storm/solver/SymbolicEquationSolver.cpp +++ b/src/storm/solver/SymbolicEquationSolver.cpp @@ -62,7 +62,47 @@ namespace storm { } template - storm::dd::Add SymbolicEquationSolver::getLowerBounds() const { + bool SymbolicEquationSolver::hasLowerBound() const { + return static_cast(lowerBound); + } + + template + ValueType const& SymbolicEquationSolver::getLowerBound() const { + return lowerBound.get(); + } + + template + bool SymbolicEquationSolver::hasLowerBounds() const { + return static_cast(lowerBounds); + } + + template + storm::dd::Add const& SymbolicEquationSolver::getLowerBounds() const { + return lowerBounds.get(); + } + + template + bool SymbolicEquationSolver::hasUpperBound() const { + return static_cast(upperBound); + } + + template + ValueType const& SymbolicEquationSolver::getUpperBound() const { + return upperBound.get(); + } + + template + bool SymbolicEquationSolver::hasUpperBounds() const { + return static_cast(upperBounds); + } + + template + storm::dd::Add const& SymbolicEquationSolver::getUpperBounds() const { + return upperBounds.get(); + } + + template + storm::dd::Add SymbolicEquationSolver::getLowerBoundsVector() const { STORM_LOG_THROW(lowerBound || lowerBounds, storm::exceptions::UnmetRequirementException, "Requiring lower bounds, but did not get any."); if (lowerBounds) { return lowerBounds.get(); @@ -72,7 +112,7 @@ namespace storm { } template - storm::dd::Add SymbolicEquationSolver::getUpperBounds() const { + storm::dd::Add SymbolicEquationSolver::getUpperBoundsVector() const { STORM_LOG_THROW(upperBound || upperBounds, storm::exceptions::UnmetRequirementException, "Requiring upper bounds, but did not get any."); if (upperBounds) { return upperBounds.get(); diff --git a/src/storm/solver/SymbolicEquationSolver.h b/src/storm/solver/SymbolicEquationSolver.h index 2e6322cfc..73681d762 100644 --- a/src/storm/solver/SymbolicEquationSolver.h +++ b/src/storm/solver/SymbolicEquationSolver.h @@ -14,22 +14,31 @@ namespace storm { SymbolicEquationSolver() = default; SymbolicEquationSolver(storm::dd::Bdd const& allRows); - void setLowerBounds(storm::dd::Add const& lowerBounds); - void setLowerBound(ValueType const& lowerBound); - void setUpperBounds(storm::dd::Add const& upperBounds); - void setUpperBound(ValueType const& lowerBound); - void setBounds(ValueType const& lowerBound, ValueType const& upperBound); - void setBounds(storm::dd::Add const& lowerBounds, storm::dd::Add const& upperBounds); - + virtual void setLowerBounds(storm::dd::Add const& lowerBounds); + virtual void setLowerBound(ValueType const& lowerBound); + virtual void setUpperBounds(storm::dd::Add const& upperBounds); + virtual void setUpperBound(ValueType const& lowerBound); + virtual void setBounds(ValueType const& lowerBound, ValueType const& upperBound); + virtual void setBounds(storm::dd::Add const& lowerBounds, storm::dd::Add const& upperBounds); + + bool hasLowerBound() const; + ValueType const& getLowerBound() const; + bool hasLowerBounds() const; + storm::dd::Add const& getLowerBounds() const; + bool hasUpperBound() const; + ValueType const& getUpperBound() const; + bool hasUpperBounds() const; + storm::dd::Add const& getUpperBounds() const; + /*! * Retrieves a vector of lower bounds for all values (if any lower bounds are known). */ - storm::dd::Add getLowerBounds() const; + storm::dd::Add getLowerBoundsVector() const; /*! * Retrieves a vector of upper bounds for all values (if any lower bounds are known). */ - storm::dd::Add getUpperBounds() const; + storm::dd::Add getUpperBoundsVector() const; protected: storm::dd::DdManager& getDdManager() const; diff --git a/src/storm/solver/SymbolicMinMaxLinearEquationSolver.cpp b/src/storm/solver/SymbolicMinMaxLinearEquationSolver.cpp index 4cb05ee90..2df70a371 100644 --- a/src/storm/solver/SymbolicMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/SymbolicMinMaxLinearEquationSolver.cpp @@ -226,7 +226,7 @@ namespace storm { template template typename std::enable_if::value && storm::NumberTraits::IsExact, storm::dd::Add>::type SymbolicMinMaxLinearEquationSolver::solveEquationsRationalSearchHelper(storm::solver::OptimizationDirection const& dir, storm::dd::Add const& x, storm::dd::Add const& b) const { - return solveEquationsRationalSearchHelper(dir, *this, *this, b, this->getLowerBounds(), b); + return solveEquationsRationalSearchHelper(dir, *this, *this, b, this->getLowerBoundsVector(), b); } template @@ -236,7 +236,7 @@ namespace storm { storm::dd::Add rationalB = b.template toValueType(); SymbolicMinMaxLinearEquationSolver rationalSolver(this->A.template toValueType(), this->allRows, this->illegalMask, this->rowMetaVariables, this->columnMetaVariables, this->choiceVariables, this->rowColumnMetaVariablePairs, std::make_unique>()); - storm::dd::Add rationalResult = solveEquationsRationalSearchHelper(dir, rationalSolver, *this, rationalB, this->getLowerBounds(), b); + storm::dd::Add rationalResult = solveEquationsRationalSearchHelper(dir, rationalSolver, *this, rationalB, this->getLowerBoundsVector(), b); return rationalResult.template toValueType(); } @@ -248,7 +248,7 @@ namespace storm { storm::dd::Add rationalResult; storm::dd::Add impreciseX; try { - impreciseX = this->getLowerBounds().template toValueType(); + impreciseX = this->getLowerBoundsVector().template toValueType(); storm::dd::Add impreciseB = b.template toValueType(); SymbolicMinMaxLinearEquationSolver impreciseSolver(this->A.template toValueType(), this->allRows, this->illegalMask, this->rowMetaVariables, this->columnMetaVariables, this->choiceVariables, this->rowColumnMetaVariablePairs, std::make_unique>()); @@ -276,7 +276,7 @@ namespace storm { if (this->hasInitialScheduler()) { localX = solveEquationsWithScheduler(this->getInitialScheduler(), x, b); } else { - localX = this->getLowerBounds(); + localX = this->getLowerBoundsVector(); } ValueIterationResult viResult = performValueIteration(dir, localX, b, this->getSettings().getPrecision(), this->getSettings().getRelativeTerminationCriterion(), this->settings.getMaximalNumberOfIterations()); @@ -294,6 +294,7 @@ namespace storm { storm::dd::Add SymbolicMinMaxLinearEquationSolver::solveEquationsWithScheduler(storm::dd::Bdd const& scheduler, storm::dd::Add const& x, storm::dd::Add const& b) const { std::unique_ptr> solver = linearEquationSolverFactory->create(this->allRows, this->rowMetaVariables, this->columnMetaVariables, this->rowColumnMetaVariablePairs); + this->forwardBounds(*solver); storm::dd::Add diagonal = (storm::utility::dd::getRowColumnDiagonal(x.getDdManager(), this->rowColumnMetaVariablePairs) && this->allRows).template toAdd(); return solveEquationsWithScheduler(*solver, scheduler, x, b, diagonal); } @@ -327,6 +328,7 @@ namespace storm { // Initialize linear equation solver. std::unique_ptr> linearEquationSolver = linearEquationSolverFactory->create(this->allRows, this->rowMetaVariables, this->columnMetaVariables, this->rowColumnMetaVariablePairs); + this->forwardBounds(*linearEquationSolver); // Iteratively solve and improve the scheduler. while (!converged && iterations < this->settings.getMaximalNumberOfIterations()) { @@ -452,6 +454,22 @@ namespace storm { SymbolicMinMaxLinearEquationSolverSettings const& SymbolicMinMaxLinearEquationSolver::getSettings() const { return settings; } + + template + void SymbolicMinMaxLinearEquationSolver::forwardBounds(storm::solver::SymbolicLinearEquationSolver& solver) const { + if (this->hasLowerBound()) { + solver.setLowerBound(this->getLowerBound()); + } + if (this->hasLowerBounds()) { + solver.setLowerBounds(this->getLowerBounds()); + } + if (this->hasUpperBound()) { + solver.setUpperBound(this->getUpperBound()); + } + if (this->hasUpperBounds()) { + solver.setUpperBounds(this->getUpperBounds()); + } + } template MinMaxLinearEquationSolverRequirements SymbolicMinMaxLinearEquationSolverFactory::getRequirements(EquationSystemType const& equationSystemType, boost::optional const& direction) const { diff --git a/src/storm/solver/SymbolicMinMaxLinearEquationSolver.h b/src/storm/solver/SymbolicMinMaxLinearEquationSolver.h index 4e5fadbfd..6853582f1 100644 --- a/src/storm/solver/SymbolicMinMaxLinearEquationSolver.h +++ b/src/storm/solver/SymbolicMinMaxLinearEquationSolver.h @@ -213,6 +213,12 @@ namespace storm { // A scheduler that specifies with which schedulers to start. boost::optional> initialScheduler; + + private: + /*! + * Forwards the known bounds of this solver to the given linear equation solver. + */ + void forwardBounds(storm::solver::SymbolicLinearEquationSolver& solver) const; }; template diff --git a/src/storm/solver/SymbolicNativeLinearEquationSolver.cpp b/src/storm/solver/SymbolicNativeLinearEquationSolver.cpp index 5f3613100..5c6da4572 100644 --- a/src/storm/solver/SymbolicNativeLinearEquationSolver.cpp +++ b/src/storm/solver/SymbolicNativeLinearEquationSolver.cpp @@ -185,7 +185,7 @@ namespace storm { template storm::dd::Add SymbolicNativeLinearEquationSolver::solveEquationsPower(storm::dd::Add const& x, storm::dd::Add const& b) const { - PowerIterationResult result = performPowerIteration(this->getLowerBounds(), b, this->getSettings().getPrecision(), this->getSettings().getRelativeTerminationCriterion(), this->getSettings().getMaximalNumberOfIterations()); + PowerIterationResult result = performPowerIteration(this->getLowerBoundsVector(), b, this->getSettings().getPrecision(), this->getSettings().getRelativeTerminationCriterion(), this->getSettings().getMaximalNumberOfIterations()); if (result.status == SolverStatus::Converged) { STORM_LOG_INFO("Iterative solver (power iteration) converged in " << result.iterations << " iterations."); @@ -195,7 +195,7 @@ namespace storm { return result.values; } - + template bool SymbolicNativeLinearEquationSolver::isSolutionFixedPoint(storm::dd::Add const& x, storm::dd::Add const& b) const { storm::dd::Add xAsColumn = x.swapVariables(this->rowColumnMetaVariablePairs); @@ -273,7 +273,7 @@ namespace storm { template template typename std::enable_if::value && storm::NumberTraits::IsExact, storm::dd::Add>::type SymbolicNativeLinearEquationSolver::solveEquationsRationalSearchHelper(storm::dd::Add const& x, storm::dd::Add const& b) const { - return solveEquationsRationalSearchHelper(*this, *this, b, this->getLowerBounds(), b); + return solveEquationsRationalSearchHelper(*this, *this, b, this->getLowerBoundsVector(), b); } template @@ -283,7 +283,7 @@ namespace storm { storm::dd::Add rationalB = b.template toValueType(); SymbolicNativeLinearEquationSolver rationalSolver(this->A.template toValueType(), this->allRows, this->rowMetaVariables, this->columnMetaVariables, this->rowColumnMetaVariablePairs); - storm::dd::Add rationalResult = solveEquationsRationalSearchHelper(rationalSolver, *this, rationalB, this->getLowerBounds(), b); + storm::dd::Add rationalResult = solveEquationsRationalSearchHelper(rationalSolver, *this, rationalB, this->getLowerBoundsVector(), b); return rationalResult.template toValueType(); } @@ -295,7 +295,7 @@ namespace storm { storm::dd::Add rationalResult; storm::dd::Add impreciseX; try { - impreciseX = this->getLowerBounds().template toValueType(); + impreciseX = this->getLowerBoundsVector().template toValueType(); storm::dd::Add impreciseB = b.template toValueType(); SymbolicNativeLinearEquationSolver impreciseSolver(this->A.template toValueType(), this->allRows, this->rowMetaVariables, this->columnMetaVariables, this->rowColumnMetaVariablePairs); diff --git a/src/test/storm/storage/SylvanDdTest.cpp b/src/test/storm/storage/SylvanDdTest.cpp index 48778a174..ddb937590 100644 --- a/src/test/storm/storage/SylvanDdTest.cpp +++ b/src/test/storm/storage/SylvanDdTest.cpp @@ -855,6 +855,26 @@ TEST(SylvanDd, AddSharpenTest) { ASSERT_EQ(storm::utility::convertNumber(std::string("19/10")), sharpened.getValue(metaVariableToValueMap)); } +TEST(SylvanDd, AddRationalSharpenTest) { + std::shared_ptr> manager(new storm::dd::DdManager()); + std::pair x = manager->addMetaVariable("x", 1, 9); + + storm::dd::Add dd = manager->template getAddOne(); + ASSERT_NO_THROW(dd.setValue(x.first, 4, storm::utility::convertNumber(1.89999999))); + ASSERT_EQ(2ul, dd.getLeafCount()); + + storm::dd::Add sharpened = dd.sharpenKwekMehlhorn(1); + + std::map metaVariableToValueMap; + metaVariableToValueMap.emplace(x.first, 4); + + sharpened = dd.sharpenKwekMehlhorn(1); + ASSERT_EQ(storm::utility::convertNumber(std::string("9/5")), sharpened.getValue(metaVariableToValueMap)); + + sharpened = dd.sharpenKwekMehlhorn(2); + ASSERT_EQ(storm::utility::convertNumber(std::string("19/10")), sharpened.getValue(metaVariableToValueMap)); +} + TEST(SylvanDd, AddToRationalTest) { std::shared_ptr> manager(new storm::dd::DdManager()); std::pair x = manager->addMetaVariable("x", 1, 9);