diff --git a/resources/3rdparty/sylvan/src/storm_function_wrapper.cpp b/resources/3rdparty/sylvan/src/storm_function_wrapper.cpp index 4c7765781..9f1bebc2a 100644 --- a/resources/3rdparty/sylvan/src/storm_function_wrapper.cpp +++ b/resources/3rdparty/sylvan/src/storm_function_wrapper.cpp @@ -3,9 +3,16 @@ #include #include #include +#include +#include #include "src/adapters/CarlAdapter.h" #include "sylvan_storm_rational_function.h" +#include +#include +#include +#include + #undef DEBUG_STORM_FUNCTION_WRAPPER #ifdef DEBUG_STORM_FUNCTION_WRAPPER @@ -206,16 +213,44 @@ void print_storm_rational_function_to_file(storm_rational_function_ptr a, FILE* fprintf(out, "%s", s.c_str()); } -MTBDD storm_rational_function_leaf_parameter_replacement(uint64_t node_value, uint32_t node_type, void* context) { - if (node_type != sylvan_storm_rational_function_get_type()) { - // - } else { - // +MTBDD testiTest(storm::RationalFunction const& currentFunction, std::map>> const& replacements) { + if (currentFunction.isConstant()) { + std::cout << "Is constant, returning f = " << currentFunction << std::endl; + return mtbdd_storm_rational_function((storm_rational_function_ptr)¤tFunction); + } + + std::set variablesInFunction = currentFunction.gatherVariables(); + std::cout << "Entered testiTest with f = " << currentFunction << " and " << variablesInFunction.size() << " Variables left." << std::endl; + + std::map>>::const_iterator it = replacements.cbegin(); + std::map>>::const_iterator end = replacements.cend(); + + // Walking the (ordered) map enforces an ordering on the MTBDD + for (; it != end; ++it) { + if (variablesInFunction.find(it->first) != variablesInFunction.cend()) { + std::cout << "Replacing variable!" << std::endl; + std::map highReplacement = {{it->first, it->second.second.first}}; + std::map lowReplacement = {{it->first, it->second.second.second}}; + std::cout << "High Function = " << currentFunction.substitute(highReplacement) << " w. replc = " << it->second.second.first << std::endl; + MTBDD high = testiTest(currentFunction.substitute(highReplacement), replacements); + std::cout << "Low Function = " << currentFunction.substitute(lowReplacement) << " w. replc = " << it->second.second.second << std::endl; + MTBDD low = testiTest(currentFunction.substitute(lowReplacement), replacements); + LACE_ME + return mtbdd_ite(mtbdd_ithvar(it->second.first), high, low); + } } - (void)node_value; - (void)node_type; - (void)context; + std::cout << "Found no variable, returning..." << std::endl; + return mtbdd_storm_rational_function((storm_rational_function_ptr)¤tFunction); +} + + +MTBDD storm_rational_function_leaf_parameter_replacement(MTBDD dd, storm_rational_function_ptr a, void* context) { + storm::RationalFunction& srf_a = *(storm::RationalFunction*)a; + if (srf_a.isConstant()) { + return dd; + } - return mtbdd_invalid; + std::map>>* replacements = (std::map>>*)context; + return testiTest(srf_a, *replacements); } diff --git a/resources/3rdparty/sylvan/src/storm_function_wrapper.h b/resources/3rdparty/sylvan/src/storm_function_wrapper.h index 0530dc584..40c31dd0f 100644 --- a/resources/3rdparty/sylvan/src/storm_function_wrapper.h +++ b/resources/3rdparty/sylvan/src/storm_function_wrapper.h @@ -31,7 +31,7 @@ void print_storm_rational_function_to_file(storm_rational_function_ptr a, FILE* int storm_rational_function_is_zero(storm_rational_function_ptr a); -MTBDD storm_rational_function_leaf_parameter_replacement(uint64_t node_value, uint32_t node_type, void* context); +MTBDD storm_rational_function_leaf_parameter_replacement(MTBDD dd, storm_rational_function_ptr a, void* context); #ifdef __cplusplus } diff --git a/resources/3rdparty/sylvan/src/sylvan_obj_mtbdd_storm.hpp b/resources/3rdparty/sylvan/src/sylvan_obj_mtbdd_storm.hpp index d71aaa232..26d79732f 100644 --- a/resources/3rdparty/sylvan/src/sylvan_obj_mtbdd_storm.hpp +++ b/resources/3rdparty/sylvan/src/sylvan_obj_mtbdd_storm.hpp @@ -30,6 +30,8 @@ Mtbdd DivideRF(const Mtbdd &other) const; Mtbdd AbstractPlusRF(const BddSet &variables) const; + + Mtbdd ReplaceLeavesRF(void* context) const; #endif /** diff --git a/resources/3rdparty/sylvan/src/sylvan_obj_storm.cpp b/resources/3rdparty/sylvan/src/sylvan_obj_storm.cpp index 570f29626..008114e20 100644 --- a/resources/3rdparty/sylvan/src/sylvan_obj_storm.cpp +++ b/resources/3rdparty/sylvan/src/sylvan_obj_storm.cpp @@ -57,6 +57,11 @@ Mtbdd Mtbdd::AbstractPlusRF(const BddSet &variables) const { return sylvan_storm_rational_function_abstract_plus(mtbdd, variables.set.bdd); } +Mtbdd Mtbdd::ReplaceLeavesRF(void* context) const { + LACE_ME; + return sylvan_storm_rational_function_replace_leaves(mtbdd, (size_t)context); +} + #endif Mtbdd diff --git a/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.c b/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.c index 4026d389b..ae08ee588 100644 --- a/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.c +++ b/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.c @@ -380,7 +380,7 @@ TASK_IMPL_2(MTBDD, sylvan_storm_rational_function_op_neg, MTBDD, dd, size_t, p) /** * Operation "replace leaves" for one storm::RationalFunction MTBDD */ -TASK_IMPL_2(MTBDD, sylvan_storm_rational_function_op_replace_leaves, MTBDD, dd, void*, context) +TASK_IMPL_2(MTBDD, sylvan_storm_rational_function_op_replace_leaves, MTBDD, dd, size_t, context) { LOG_I("task_impl_2 op_replace") /* Handle partial functions */ @@ -388,7 +388,12 @@ TASK_IMPL_2(MTBDD, sylvan_storm_rational_function_op_replace_leaves, MTBDD, dd, /* Compute result for leaf */ if (mtbdd_isleaf(dd)) { - return storm_rational_function_leaf_parameter_replacement(mtbdd_getvalue(dd), mtbdd_gettype(dd), context); + if (mtbdd_gettype(dd) != sylvan_storm_rational_function_type) { + assert(0); + } + + storm_rational_function_ptr mdd = (storm_rational_function_ptr)mtbdd_getvalue(dd); + return storm_rational_function_leaf_parameter_replacement(dd, mdd, (void*)context); } return mtbdd_invalid; diff --git a/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.h b/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.h index b64761955..458bc5002 100644 --- a/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.h +++ b/resources/3rdparty/sylvan/src/sylvan_storm_rational_function.h @@ -116,12 +116,12 @@ typedef MTBDD (*mtbddLeaveReplacementFunction)(uint64_t, uint32_t, void*); /** * Operation "replace" for one storm::RationalFunction MTBDD */ -TASK_DECL_2(MTBDD, sylvan_storm_rational_function_op_replace_leaves, MTBDD, void*) +TASK_DECL_2(MTBDD, sylvan_storm_rational_function_op_replace_leaves, MTBDD, size_t) /** * Compute the MTBDD that arises from a after calling the mtbddLeaveReplacementFunction on each leaf. */ -#define sylvan_storm_rational_function_replace_leaves(a, func, ctx) mtbdd_uapply(a, TASK(sylvan_storm_rational_function_op_replace_leaves), ctx) +#define sylvan_storm_rational_function_replace_leaves(a, ctx) mtbdd_uapply(a, TASK(sylvan_storm_rational_function_op_replace_leaves), ctx) #ifdef __cplusplus } diff --git a/src/storage/dd/Add.cpp b/src/storage/dd/Add.cpp index 723b9d088..b1295d663 100644 --- a/src/storage/dd/Add.cpp +++ b/src/storage/dd/Add.cpp @@ -784,6 +784,34 @@ namespace storm { return internalAdd; } +#ifdef STORM_HAVE_CARL + template + Add Add::replaceLeaves(std::map>> const& replacementMap) const { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Not yet implemented: replaceLeaves"); + } + + template<> + Add Add::replaceLeaves(std::map>> const& replacementMap) const { + std::map>> internalReplacementMap; + std::set containedMetaVariables = this->getContainedMetaVariables(); + + std::map>>::const_iterator it = replacementMap.cbegin(); + std::map>>::const_iterator end = replacementMap.cend(); + + for (; it != end; ++it) { + DdMetaVariable const& metaVariable = this->getDdManager().getMetaVariable(it->second.first); + STORM_LOG_THROW(metaVariable.getNumberOfDdVariables() == 1, storm::exceptions::InvalidArgumentException, "Cannot use MetaVariable with more then one internal DD variable."); + + auto const& ddVariable = metaVariable.getDdVariables().at(0); + + internalReplacementMap.insert(std::make_pair(it->first, std::make_pair(ddVariable.getIndex(), it->second.second))); + containedMetaVariables.insert(it->second.first); + } + + return Add(this->getDdManager(), internalAdd.replaceLeaves(internalReplacementMap), containedMetaVariables); + } +#endif + template class Add; template class Add; diff --git a/src/storage/dd/Add.h b/src/storage/dd/Add.h index 6185cee96..9cda98dd2 100644 --- a/src/storage/dd/Add.h +++ b/src/storage/dd/Add.h @@ -14,6 +14,9 @@ #include "src/storage/dd/cudd/CuddAddIterator.h" #include "src/storage/dd/sylvan/SylvanAddIterator.h" +#include "storm-config.h" +#include "src/adapters/CarlAdapter.h" + namespace storm { namespace dd { template @@ -243,6 +246,16 @@ namespace storm { */ Add maximum(Add const& other) const; +#ifdef STORM_HAVE_CARL + /*! + * Replaces the leaves in this MTBDD, using the supplied variable replacement map. + * + * @param replacementMap The variable replacement map. + * @return The resulting function represented as an ADD. + */ + Add replaceLeaves(std::map>> const& replacementMap) const; +#endif + /*! * Sum-abstracts from the given meta variables. * diff --git a/src/storage/dd/sylvan/InternalSylvanAdd.cpp b/src/storage/dd/sylvan/InternalSylvanAdd.cpp index 4d69a296b..9c7701936 100644 --- a/src/storage/dd/sylvan/InternalSylvanAdd.cpp +++ b/src/storage/dd/sylvan/InternalSylvanAdd.cpp @@ -289,6 +289,19 @@ namespace storm { } #endif + template + InternalAdd InternalAdd::replaceLeaves(std::map>> const& replacementMap) const { + STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "Not yet implemented: replaceLeaves"); + } + +#ifdef STORM_HAVE_CARL + template<> + InternalAdd InternalAdd::replaceLeaves(std::map>> const& replacementMap) const { + return InternalAdd(ddManager, this->sylvanMtbdd.ReplaceLeavesRF((void*)&replacementMap)); + } +#endif + + template InternalAdd InternalAdd::sumAbstract(InternalBdd const& cube) const { return InternalAdd(ddManager, this->sylvanMtbdd.AbstractPlus(cube.sylvanBdd)); diff --git a/src/storage/dd/sylvan/InternalSylvanAdd.h b/src/storage/dd/sylvan/InternalSylvanAdd.h index 39c428679..ab5239a39 100644 --- a/src/storage/dd/sylvan/InternalSylvanAdd.h +++ b/src/storage/dd/sylvan/InternalSylvanAdd.h @@ -248,6 +248,16 @@ namespace storm { */ InternalAdd maximum(InternalAdd const& other) const; +#ifdef STORM_HAVE_CARL + /*! + * Replaces the leaves in this MTBDD, using the supplied variable replacement map. + * + * @param replacementMap The variable replacement map. + * @return The resulting function represented as an ADD. + */ + InternalAdd replaceLeaves(std::map>> const& replacementMap) const; +#endif + /*! * Sum-abstracts from the given cube. * diff --git a/test/functional/storage/SylvanDdTest.cpp b/test/functional/storage/SylvanDdTest.cpp index 46be6c67a..83e7e3d80 100644 --- a/test/functional/storage/SylvanDdTest.cpp +++ b/test/functional/storage/SylvanDdTest.cpp @@ -414,6 +414,322 @@ TEST(SylvanDd, EncodingTest) { } #ifdef STORM_HAVE_CARL +TEST(SylvanDd, RationalFunctionLeaveReplacementNonVariable) { + std::shared_ptr> manager(new storm::dd::DdManager()); + storm::dd::Add zero; + ASSERT_NO_THROW(zero = manager->template getAddZero()); + + std::map>> replacementMap; + storm::dd::Add zeroReplacementResult = zero.replaceLeaves(replacementMap); + + EXPECT_EQ(0ul, zeroReplacementResult.getNonZeroCount()); + EXPECT_EQ(1ul, zeroReplacementResult.getLeafCount()); + EXPECT_EQ(1ul, zeroReplacementResult.getNodeCount()); + EXPECT_TRUE(zeroReplacementResult == zero); +} + +TEST(SylvanDd, RationalFunctionLeaveReplacementSimpleVariable) { + std::shared_ptr> manager(new storm::dd::DdManager()); + + // The cache that is used in case the underlying type needs a cache. + std::shared_ptr>> cache = std::make_shared>>(); + + storm::dd::Add function; + carl::Variable x = carl::freshRealVariable("x"); + storm::RationalFunction variableX = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(x), cache)); + ASSERT_NO_THROW(function = manager->template getConstant(variableX)); + + std::pair xExpr; + ASSERT_NO_THROW(xExpr = manager->addMetaVariable("x", 0, 1)); + + std::map>> replacementMap; + storm::RationalNumber rnOneThird = storm::RationalNumber(1) / storm::RationalNumber(3); + storm::RationalNumber rnTwoThird = storm::RationalNumber(2) / storm::RationalNumber(3); + replacementMap.insert(std::make_pair(x, std::make_pair(xExpr.first, std::make_pair(rnOneThird, rnTwoThird)))); + + storm::dd::Add replacedAddSimpleX = function.replaceLeaves(replacementMap); + + storm::dd::Bdd bddX0 = manager->getEncoding(xExpr.first, 0); + storm::dd::Bdd bddX1 = manager->getEncoding(xExpr.first, 1); + + storm::dd::Add complexAdd = + (bddX0.template toAdd() * manager->template getConstant(storm::RationalFunction(rnTwoThird))) + + (bddX1.template toAdd() * manager->template getConstant(storm::RationalFunction(rnOneThird))); + + EXPECT_EQ(2ul, replacedAddSimpleX.getNonZeroCount()); + EXPECT_EQ(2ul, replacedAddSimpleX.getLeafCount()); + EXPECT_EQ(3ul, replacedAddSimpleX.getNodeCount()); + EXPECT_TRUE(replacedAddSimpleX == complexAdd); +} + +TEST(SylvanDd, RationalFunctionLeaveReplacementTwoVariables) { + std::shared_ptr> manager(new storm::dd::DdManager()); + + // The cache that is used in case the underlying type needs a cache. + std::shared_ptr>> cache = std::make_shared>>(); + + storm::dd::Add function; + carl::Variable x = carl::freshRealVariable("x"); + carl::Variable y = carl::freshRealVariable("y"); + storm::RationalFunction variableX = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(x), cache)); + storm::RationalFunction variableY = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(y), cache)); + ASSERT_NO_THROW(function = manager->template getConstant(variableX * variableY)); + + std::pair xExpr; + std::pair yExpr; + ASSERT_NO_THROW(xExpr = manager->addMetaVariable("x", 0, 1)); + ASSERT_NO_THROW(yExpr = manager->addMetaVariable("y", 0, 1)); + + std::map>> replacementMap; + storm::RationalNumber rnOneThird = storm::RationalNumber(1) / storm::RationalNumber(3); + storm::RationalNumber rnTwoThird = storm::RationalNumber(2) / storm::RationalNumber(3); + storm::RationalNumber rnOne = storm::RationalNumber(1); + storm::RationalNumber rnTen = storm::RationalNumber(10); + replacementMap.insert(std::make_pair(x, std::make_pair(xExpr.first, std::make_pair(rnOneThird, rnTwoThird)))); + replacementMap.insert(std::make_pair(y, std::make_pair(yExpr.first, std::make_pair(rnOne, rnTen)))); + + storm::dd::Add replacedAdd = function.replaceLeaves(replacementMap); + + storm::dd::Bdd bddX0 = manager->getEncoding(xExpr.first, 0); + storm::dd::Bdd bddX1 = manager->getEncoding(xExpr.first, 1); + storm::dd::Bdd bddY0 = manager->getEncoding(yExpr.first, 0); + storm::dd::Bdd bddY1 = manager->getEncoding(yExpr.first, 1); + + storm::dd::Add complexAdd = + ((bddX0 && bddY0).template toAdd() * manager->template getConstant(storm::RationalFunction(rnTwoThird * rnTen))) + + ((bddX0 && bddY1).template toAdd() * manager->template getConstant(storm::RationalFunction(rnTwoThird))) + + ((bddX1 && bddY0).template toAdd() * manager->template getConstant(storm::RationalFunction(rnOneThird * rnTen))) + + ((bddX1 && bddY1).template toAdd() * manager->template getConstant(storm::RationalFunction(rnOneThird))); + + EXPECT_EQ(4ul, replacedAdd.getNonZeroCount()); + EXPECT_EQ(4ul, replacedAdd.getLeafCount()); + EXPECT_EQ(7ul, replacedAdd.getNodeCount()); + EXPECT_TRUE(replacedAdd == complexAdd); +} + +TEST(SylvanDd, RationalFunctionBullshitTest) { + // The cache that is used in case the underlying type needs a cache. + std::shared_ptr>> cache = std::make_shared>>(); + + carl::Variable x = carl::freshRealVariable("x"); + carl::Variable y = carl::freshRealVariable("y"); + carl::Variable z = carl::freshRealVariable("z"); + storm::RationalFunction variableX = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(x), cache)); + storm::RationalFunction variableY = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(y), cache)); + storm::RationalFunction variableZ = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(z), cache)); + + storm::RationalFunction constantOne(1); + storm::RationalFunction constantTwo(2); + storm::RationalFunction constantOneDivTwo(constantOne / constantTwo); + storm::RationalFunction tmpFunctionA(constantOneDivTwo); + tmpFunctionA *= variableZ; + tmpFunctionA /= variableY; + storm::RationalFunction tmpFunctionB(variableX); + tmpFunctionB *= variableY; + + //storm::RationalFunction rationalFunction(two * x + x*y + constantOneDivTwo * z / y); + storm::RationalFunction rationalFunction(constantTwo); + rationalFunction *= variableX; + rationalFunction += tmpFunctionB; + rationalFunction += tmpFunctionA; + + std::map replacement = {{x, storm::RationalNumber(2)}}; + storm::RationalFunction subX = rationalFunction.substitute(replacement); + + ASSERT_EQ(subX, storm::RationalFunction(4) + storm::RationalFunction(2) * variableY + tmpFunctionA); +} + +TEST(SylvanDd, RationalFunctionLeaveReplacementComplexFunction) { + std::shared_ptr> manager(new storm::dd::DdManager()); + + // The cache that is used in case the underlying type needs a cache. + std::shared_ptr>> cache = std::make_shared>>(); + + storm::dd::Add function; + carl::Variable x = carl::freshRealVariable("x"); + carl::Variable y = carl::freshRealVariable("y"); + carl::Variable z = carl::freshRealVariable("z"); + storm::RationalFunction variableX = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(x), cache)); + storm::RationalFunction variableY = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(y), cache)); + storm::RationalFunction variableZ = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(z), cache)); + + storm::RationalFunction constantOne(1); + storm::RationalFunction constantTwo(2); + storm::RationalFunction constantOneDivTwo(constantOne / constantTwo); + storm::RationalFunction tmpFunctionA(constantOneDivTwo); + tmpFunctionA *= variableZ; + tmpFunctionA /= variableY; + storm::RationalFunction tmpFunctionB(variableX); + tmpFunctionB *= variableY; + + //storm::RationalFunction rationalFunction(two * x + x*y + constantOneDivTwo * z / y); + storm::RationalFunction rationalFunction(constantTwo); + rationalFunction *= variableX; + rationalFunction += tmpFunctionB; + rationalFunction += tmpFunctionA; + + ASSERT_NO_THROW(function = manager->template getConstant(rationalFunction)); + + std::pair xExpr; + std::pair yExpr; + std::pair zExpr; + ASSERT_NO_THROW(xExpr = manager->addMetaVariable("x", 0, 1)); + ASSERT_NO_THROW(yExpr = manager->addMetaVariable("y", 0, 1)); + ASSERT_NO_THROW(zExpr = manager->addMetaVariable("z", 0, 1)); + + std::map>> replacementMap; + storm::RationalNumber rnTwo(2); + storm::RationalNumber rnThree(3); + storm::RationalNumber rnFive(5); + storm::RationalNumber rnSeven(7); + storm::RationalNumber rnEleven(11); + storm::RationalNumber rnThirteen(13); + replacementMap.insert(std::make_pair(x, std::make_pair(xExpr.first, std::make_pair(rnTwo, rnSeven)))); + replacementMap.insert(std::make_pair(y, std::make_pair(yExpr.first, std::make_pair(rnThree, rnEleven)))); + replacementMap.insert(std::make_pair(z, std::make_pair(zExpr.first, std::make_pair(rnFive, rnThirteen)))); + + storm::dd::Add replacedAdd = function.replaceLeaves(replacementMap); + + storm::dd::Bdd bddX0 = manager->getEncoding(xExpr.first, 0); + storm::dd::Bdd bddX1 = manager->getEncoding(xExpr.first, 1); + storm::dd::Bdd bddY0 = manager->getEncoding(yExpr.first, 0); + storm::dd::Bdd bddY1 = manager->getEncoding(yExpr.first, 1); + storm::dd::Bdd bddZ0 = manager->getEncoding(zExpr.first, 0); + storm::dd::Bdd bddZ1 = manager->getEncoding(zExpr.first, 1); + + auto f = [&](bool x, bool y, bool z) { + storm::RationalNumber result(2); + if (x) { + result *= rnSeven; + } else { + result *= rnTwo; + } + + storm::RationalNumber partTwo(1); + if (x) { + partTwo *= rnSeven; + } else { + partTwo *= rnTwo; + } + if (y) { + partTwo *= rnEleven; + } else { + partTwo *= rnThree; + } + + storm::RationalNumber partThree(1); + if (z) { + partThree *= rnThirteen; + } else { + partThree *= rnFive; + } + if (y) { + partThree /= storm::RationalNumber(2) * rnEleven; + } else { + partThree /= storm::RationalNumber(2) * rnThree; + } + + return result + partTwo + partThree; + }; + + storm::dd::Add complexAdd = + ((bddX0 && (bddY0 && bddZ0)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(false, false, false)))) + + ((bddX0 && (bddY0 && bddZ1)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(false, false, true)))) + + ((bddX0 && (bddY1 && bddZ0)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(false, true, false)))) + + ((bddX0 && (bddY1 && bddZ1)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(false, true, true)))) + + ((bddX1 && (bddY0 && bddZ0)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(true, false, false)))) + + ((bddX1 && (bddY0 && bddZ1)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(true, false, true)))) + + ((bddX1 && (bddY1 && bddZ0)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(true, true, false)))) + + ((bddX1 && (bddY1 && bddZ1)).template toAdd() * manager->template getConstant(storm::RationalFunction(f(true, true, true)))); + + EXPECT_EQ(4ul, replacedAdd.getNonZeroCount()); + EXPECT_EQ(4ul, replacedAdd.getLeafCount()); + EXPECT_EQ(7ul, replacedAdd.getNodeCount()); + EXPECT_TRUE(replacedAdd == complexAdd); + + replacedAdd.exportToDot("sylvan_replacedAddC.dot"); + complexAdd.exportToDot("sylvan_complexAddC.dot"); +} + +TEST(SylvanDd, RationalFunctionLeaveReplacementComplexFunction2) { + std::shared_ptr> manager(new storm::dd::DdManager()); + + // The cache that is used in case the underlying type needs a cache. + std::shared_ptr>> cache = std::make_shared>>(); + + storm::dd::Add function; + carl::Variable x = carl::freshRealVariable("x"); + carl::Variable y = carl::freshRealVariable("y"); + carl::Variable z = carl::freshRealVariable("z"); + + storm::RationalFunction constantOne(1); + storm::RationalFunction constantTwo(2); + + storm::RationalFunction variableX = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(x), cache)); + storm::RationalFunction variableY = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(y), cache)); + storm::RationalFunction variableZ = storm::RationalFunction(typename storm::RationalFunction::PolyType(typename storm::RationalFunction::PolyType::PolyType(z), cache)); + + storm::RationalFunction constantOneDivTwo(constantOne / constantTwo); + storm::RationalFunction tmpFunctionA(constantOneDivTwo); + tmpFunctionA *= variableZ; + tmpFunctionA /= variableY; + storm::RationalFunction tmpFunctionB(variableX); + tmpFunctionB *= variableY; + + //storm::RationalFunction rationalFunction(two * x + x*y + constantOneDivTwo * z / y); + storm::RationalFunction rationalFunction(constantTwo); + rationalFunction *= variableX; + rationalFunction += tmpFunctionB; + rationalFunction += tmpFunctionA; + + ASSERT_NO_THROW(function = manager->template getConstant(rationalFunction)); + + EXPECT_EQ(0ul, function.getNonZeroCount()); + EXPECT_EQ(1ul, function.getLeafCount()); + EXPECT_EQ(1ul, function.getNodeCount()); + + std::pair xExpr; + std::pair yExpr; + std::pair zExpr; + ASSERT_NO_THROW(xExpr = manager->addMetaVariable("x", 0, 1)); + ASSERT_NO_THROW(yExpr = manager->addMetaVariable("y", 0, 1)); + ASSERT_NO_THROW(zExpr = manager->addMetaVariable("z", 0, 1)); + + storm::dd::Bdd bddX0 = manager->getEncoding(xExpr.first, 0); + storm::dd::Bdd bddX1 = manager->getEncoding(xExpr.first, 1); + storm::dd::Bdd bddY0 = manager->getEncoding(yExpr.first, 0); + storm::dd::Bdd bddY1 = manager->getEncoding(yExpr.first, 1); + storm::dd::Bdd bddZ0 = manager->getEncoding(zExpr.first, 0); + storm::dd::Bdd bddZ1 = manager->getEncoding(zExpr.first, 1); + + storm::dd::Add functionSimpleX; + ASSERT_NO_THROW(functionSimpleX = manager->template getConstant(storm::RationalFunction(variableX))); + + std::map>> replacementMapSimpleX; + + storm::RationalNumber rnOneThird = storm::RationalNumber(1) / storm::RationalNumber(3); + storm::RationalNumber rnTwoThird = storm::RationalNumber(2) / storm::RationalNumber(3); + replacementMapSimpleX.insert(std::make_pair(x, std::make_pair(xExpr.first, std::make_pair(rnOneThird, rnTwoThird)))); + + storm::dd::Add replacedAddSimpleX = functionSimpleX.replaceLeaves(replacementMapSimpleX); + replacedAddSimpleX.exportToDot("sylvan_replacementMapSimpleX.dot"); + + std::map>> replacementMap; + + storm::RationalNumber rnMinusOne(-1); + storm::RationalNumber rnOne(1); + storm::RationalNumber rnPointOne = storm::RationalNumber(1) / storm::RationalNumber(10); + storm::RationalNumber rnPointSixSix = storm::RationalNumber(2) / storm::RationalNumber(3); + storm::RationalNumber rnPointFive = storm::RationalNumber(1) / storm::RationalNumber(2); + + replacementMap.insert(std::make_pair(x, std::make_pair(xExpr.first, std::make_pair(rnMinusOne, rnOne)))); + replacementMap.insert(std::make_pair(y, std::make_pair(yExpr.first, std::make_pair(rnPointOne, rnPointSixSix)))); + replacementMap.insert(std::make_pair(z, std::make_pair(zExpr.first, std::make_pair(rnPointFive, rnOne)))); + + storm::dd::Add replacedAdd = function.replaceLeaves(replacementMap); + replacedAdd.exportToDot("sylvan_replaceLeave.dot"); +} + TEST(SylvanDd, RationalFunctionConstants) { std::shared_ptr> manager(new storm::dd::DdManager()); storm::dd::Add zero;