diff --git a/src/adapters/DdExpressionAdapter.cpp b/src/adapters/DdExpressionAdapter.cpp index c3a63843c..e33fb40b9 100644 --- a/src/adapters/DdExpressionAdapter.cpp +++ b/src/adapters/DdExpressionAdapter.cpp @@ -9,7 +9,7 @@ namespace storm { namespace adapters { template - DdExpressionAdapter::DdExpressionAdapter(storm::dd::DdManager const& ddManager, std::map const& variableMapping) : ddManager(ddManager), variableMapping(variableMapping) { + DdExpressionAdapter::DdExpressionAdapter(std::shared_ptr> ddManager, std::map const& variableMapping) : ddManager(ddManager), variableMapping(variableMapping) { // Intentionally left empty. } @@ -118,8 +118,8 @@ namespace storm { template boost::any DdExpressionAdapter::visit(storm::expressions::VariableExpression const& expression) { auto const& variablePair = variableMapping.find(expression.getVariable()); - STORM_LOG_THROW(variablePair != variableMapping.end(), storm::exceptions::InvalidArgumentException, "Cannot translate the given expression, because it contains th variable '" << expression.getVariableName() << "' for which no DD counterpart is known."); - return ddManager.getIdentity(variablePair->second); + STORM_LOG_THROW(variablePair != variableMapping.end(), storm::exceptions::InvalidArgumentException, "Cannot translate the given expression, because it contains the variable '" << expression.getVariableName() << "' for which no DD counterpart is known."); + return ddManager->getIdentity(variablePair->second); } template @@ -152,17 +152,17 @@ namespace storm { template boost::any DdExpressionAdapter::visit(storm::expressions::BooleanLiteralExpression const& expression) { - return ddManager.getConstant(expression.getValue()); + return ddManager->getConstant(expression.getValue()); } template boost::any DdExpressionAdapter::visit(storm::expressions::IntegerLiteralExpression const& expression) { - return ddManager.getConstant(expression.getValue()); + return ddManager->getConstant(expression.getValue()); } template boost::any DdExpressionAdapter::visit(storm::expressions::DoubleLiteralExpression const& expression) { - return ddManager.getConstant(expression.getValue()); + return ddManager->getConstant(expression.getValue()); } // Explicitly instantiate the symbolic expression adapter diff --git a/src/adapters/DdExpressionAdapter.h b/src/adapters/DdExpressionAdapter.h index 6eaff65a8..cba1c422b 100644 --- a/src/adapters/DdExpressionAdapter.h +++ b/src/adapters/DdExpressionAdapter.h @@ -14,7 +14,7 @@ namespace storm { template class DdExpressionAdapter : public storm::expressions::ExpressionVisitor { public: - DdExpressionAdapter(storm::dd::DdManager const& ddManager, std::map const& variableMapping); + DdExpressionAdapter(std::shared_ptr> ddManager, std::map const& variableMapping); storm::dd::Dd translateExpression(storm::expressions::Expression const& expression); @@ -31,10 +31,10 @@ namespace storm { private: // The manager responsible for the DDs built by this adapter. - storm::dd::DdManager const& ddManager; + std::shared_ptr> ddManager; // This member maps the variables used in the expressions to the variables used by the DD manager. - std::map const& variableMapping; + std::map variableMapping; }; } // namespace adapters diff --git a/src/builder/DdPrismModelBuilder.cpp b/src/builder/DdPrismModelBuilder.cpp index be779489e..6c9a31a15 100644 --- a/src/builder/DdPrismModelBuilder.cpp +++ b/src/builder/DdPrismModelBuilder.cpp @@ -673,9 +673,9 @@ namespace storm { } if (program.getModelType() == storm::prism::Program::ModelType::DTMC) { - return std::unique_ptr>(new storm::models::symbolic::Dtmc(generationInfo.manager, reachableStates, initialStates, transitionMatrix, generationInfo.rowMetaVariables, generationInfo.rowExpressionAdapter, generationInfo.columnMetaVariables, generationInfo.columnExpressionAdapter, generationInfo.rowColumnMetaVariablePairs, labelToExpressionMapping, stateAndTransitionRewards ? stateAndTransitionRewards.get().first : boost::optional>(), stateAndTransitionRewards ? stateAndTransitionRewards.get().second : boost::optional>())); + return std::shared_ptr>(new storm::models::symbolic::Dtmc(generationInfo.manager, reachableStates, initialStates, transitionMatrix, generationInfo.rowMetaVariables, generationInfo.rowExpressionAdapter, generationInfo.columnMetaVariables, generationInfo.columnExpressionAdapter, generationInfo.rowColumnMetaVariablePairs, labelToExpressionMapping, stateAndTransitionRewards ? stateAndTransitionRewards.get().first : boost::optional>(), stateAndTransitionRewards ? stateAndTransitionRewards.get().second : boost::optional>())); } else if (program.getModelType() == storm::prism::Program::ModelType::MDP) { - return std::unique_ptr>(new storm::models::symbolic::Mdp(generationInfo.manager, reachableStates, initialStates, transitionMatrix, generationInfo.rowMetaVariables, generationInfo.rowExpressionAdapter, generationInfo.columnMetaVariables, generationInfo.columnExpressionAdapter, generationInfo.rowColumnMetaVariablePairs, generationInfo.allNondeterminismVariables, labelToExpressionMapping, stateAndTransitionRewards ? stateAndTransitionRewards.get().first : boost::optional>(), stateAndTransitionRewards ? stateAndTransitionRewards.get().second : boost::optional>())); + return std::shared_ptr>(new storm::models::symbolic::Mdp(generationInfo.manager, reachableStates, initialStates, transitionMatrix, generationInfo.rowMetaVariables, generationInfo.rowExpressionAdapter, generationInfo.columnMetaVariables, generationInfo.columnExpressionAdapter, generationInfo.rowColumnMetaVariablePairs, generationInfo.allNondeterminismVariables, labelToExpressionMapping, stateAndTransitionRewards ? stateAndTransitionRewards.get().first : boost::optional>(), stateAndTransitionRewards ? stateAndTransitionRewards.get().second : boost::optional>())); } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Invalid model type."); } diff --git a/src/builder/DdPrismModelBuilder.h b/src/builder/DdPrismModelBuilder.h index d9619710f..fb6e8c7ff 100644 --- a/src/builder/DdPrismModelBuilder.h +++ b/src/builder/DdPrismModelBuilder.h @@ -135,8 +135,8 @@ namespace storm { // Initializes variables and identity DDs. createMetaVariablesAndIdentities(); - rowExpressionAdapter = std::shared_ptr>(new storm::adapters::DdExpressionAdapter(*manager, variableToRowMetaVariableMap)); - columnExpressionAdapter = std::shared_ptr>(new storm::adapters::DdExpressionAdapter(*manager, variableToColumnMetaVariableMap)); + rowExpressionAdapter = std::shared_ptr>(new storm::adapters::DdExpressionAdapter(manager, variableToRowMetaVariableMap)); + columnExpressionAdapter = std::shared_ptr>(new storm::adapters::DdExpressionAdapter(manager, variableToColumnMetaVariableMap)); } // The program that is currently translated. diff --git a/src/models/symbolic/Model.cpp b/src/models/symbolic/Model.cpp index 64303f487..569f1eb55 100644 --- a/src/models/symbolic/Model.cpp +++ b/src/models/symbolic/Model.cpp @@ -31,6 +31,16 @@ namespace storm { return transitionMatrix.getNonZeroCount(); } + template + storm::dd::DdManager const& Model::getManager() const { + return *manager; + } + + template + storm::dd::DdManager& Model::getManager() { + return *manager; + } + template storm::dd::Dd const& Model::getReachableStates() const { return reachableStates; @@ -43,12 +53,12 @@ namespace storm { template storm::dd::Dd Model::getStates(std::string const& label) const { - return rowExpressionAdapter->translateExpression(labelToExpressionMap.at(label)); + return rowExpressionAdapter->translateExpression(labelToExpressionMap.at(label)) && this->reachableStates; } template storm::dd::Dd Model::getStates(storm::expressions::Expression const& expression) const { - return rowExpressionAdapter->translateExpression(expression); + return rowExpressionAdapter->translateExpression(expression).toBdd() && this->reachableStates; } template @@ -106,6 +116,11 @@ namespace storm { return columnVariables; } + template + std::vector> const& Model::getRowColumnMetaVariablePairs() const { + return rowColumnMetaVariablePairs; + } + template void Model::setTransitionMatrix(storm::dd::Dd const& transitionMatrix) { this->transitionMatrix = transitionMatrix; diff --git a/src/models/symbolic/Model.h b/src/models/symbolic/Model.h index 9422bde79..52dee50cb 100644 --- a/src/models/symbolic/Model.h +++ b/src/models/symbolic/Model.h @@ -68,6 +68,20 @@ namespace storm { virtual uint_fast64_t getNumberOfTransitions() const override; + /*! + * Retrieves the manager responsible for the DDs that represent this model. + * + * @return The manager responsible for the DDs that represent this model. + */ + storm::dd::DdManager const& getManager() const; + + /*! + * Retrieves the manager responsible for the DDs that represent this model. + * + * @return The manager responsible for the DDs that represent this model. + */ + storm::dd::DdManager& getManager(); + /*! * Retrieves the reachable states of the model. * @@ -172,6 +186,13 @@ namespace storm { */ std::set const& getColumnVariables() const; + /*! + * Retrieves the pairs of row and column meta variables. + * + * @return The pairs of row and column meta variables. + */ + std::vector> const& getRowColumnMetaVariablePairs() const; + virtual std::size_t getSizeInBytes() const override; virtual void printModelInformationToStream(std::ostream& out) const override; diff --git a/src/utility/graph.h b/src/utility/graph.h index 437f622cf..828617d25 100644 --- a/src/utility/graph.h +++ b/src/utility/graph.h @@ -7,6 +7,8 @@ #include "utility/OsDetection.h" #include "src/storage/sparse/StateType.h" +#include "src/models/symbolic/DeterministicModel.h" +#include "src/models/symbolic/NondeterministicModel.h" #include "src/models/sparse/DeterministicModel.h" #include "src/models/sparse/NondeterministicModel.h" #include "src/utility/constants.h" @@ -251,6 +253,56 @@ namespace storm { return result; } + /*! + * Computes the set of states that has a positive probability of reaching psi states after only passing + * through phi states before. + * + * @param model The (symbolic) model for which to compute the set of states. + * @param phiStates The phi states of the model. + * @param psiStates The psi states of the model. + * @return All states with positive probability. + */ + template + storm::dd::Dd performProbGreater0(storm::models::symbolic::DeterministicModel const& model, storm::dd::Dd const& phiStates, storm::dd::Dd const& psiStates) { + // Initialize environment for backward search. + storm::dd::DdManager const& manager = model.getManager(); + storm::dd::Dd lastIterationStates = manager.getZero(); + storm::dd::Dd statesWithProbabilityGreater0 = psiStates.toBdd(); + storm::dd::Dd phiStatesBdd = phiStates.toBdd(); + + uint_fast64_t iterations = 0; + storm::dd::Dd transitionMatrixBdd = model.getTransitionMatrix().notZero(); + while (lastIterationStates != statesWithProbabilityGreater0) { + lastIterationStates = statesWithProbabilityGreater0; + statesWithProbabilityGreater0.swapVariables(model.getRowColumnMetaVariablePairs()); + statesWithProbabilityGreater0 &= transitionMatrixBdd; + statesWithProbabilityGreater0 = statesWithProbabilityGreater0.existsAbstract(model.getColumnVariables()); + statesWithProbabilityGreater0 &= phiStatesBdd; + statesWithProbabilityGreater0 |= lastIterationStates; + ++iterations; + } + + return statesWithProbabilityGreater0; + } + + /*! + * Computes the sets of states that have probability 0 or 1, respectively, of satisfying phi until psi in a + * deterministic model. + * + * @param model The (symbolic) model for which to compute the set of states. + * @param phiStates The phi states of the model. + * @param psiStates The psi states of the model. + * @return A pair of DDs that represent all states with probability 0 and 1, respectively. + */ + template + static std::pair, storm::dd::Dd> performProb01(storm::models::symbolic::DeterministicModel const& model, storm::dd::Dd const& phiStates, storm::dd::Dd const& psiStates) { + std::pair, storm::dd::Dd> result; + result.first = performProbGreater0(model, phiStates, psiStates); + result.second = !performProbGreater0(model, !psiStates && model.getReachableStates(), !result.first && model.getReachableStates()) && model.getReachableStates(); + result.first = !result.first && model.getReachableStates(); + return result; + } + /*! * Computes the sets of states that have probability greater 0 of satisfying phi until psi under at least * one possible resolution of non-determinism in a non-deterministic model. Stated differently, diff --git a/test/functional/builder/ExplicitPrismModelBuilderTest.cpp b/test/functional/builder/ExplicitPrismModelBuilderTest.cpp index 3892a1bc8..0ff231130 100644 --- a/test/functional/builder/ExplicitPrismModelBuilderTest.cpp +++ b/test/functional/builder/ExplicitPrismModelBuilderTest.cpp @@ -7,7 +7,7 @@ TEST(ExplicitPrismModelBuilderTest, Dtmc) { storm::prism::Program program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/builder/die.pm"); - std::unique_ptr> model = storm::builder::ExplicitPrismModelBuilder::translateProgram(program); + std::shared_ptr> model = storm::builder::ExplicitPrismModelBuilder::translateProgram(program); EXPECT_EQ(13, model->getNumberOfStates()); EXPECT_EQ(20, model->getNumberOfTransitions()); @@ -35,7 +35,7 @@ TEST(ExplicitPrismModelBuilderTest, Dtmc) { TEST(ExplicitPrismModelBuilderTest, Mdp) { storm::prism::Program program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/builder/two_dice.nm"); - std::unique_ptr> model = storm::builder::ExplicitPrismModelBuilder::translateProgram(program); + std::shared_ptr> model = storm::builder::ExplicitPrismModelBuilder::translateProgram(program); EXPECT_EQ(169, model->getNumberOfStates()); EXPECT_EQ(436, model->getNumberOfTransitions()); diff --git a/test/functional/utility/GraphTest.cpp b/test/functional/utility/GraphTest.cpp new file mode 100644 index 000000000..071466846 --- /dev/null +++ b/test/functional/utility/GraphTest.cpp @@ -0,0 +1,52 @@ +#include "gtest/gtest.h" +#include "storm-config.h" + +#include "src/storage/dd/CuddDd.h" +#include "src/parser/PrismParser.h" +#include "src/models/symbolic/Dtmc.h" +#include "src/models/sparse/Dtmc.h" +#include "src/builder/DdPrismModelBuilder.h" +#include "src/builder/ExplicitPrismModelBuilder.h" +#include "src/utility/graph.h" + +TEST(GraphTest, SymbolicProb01) { + storm::prism::Program program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/builder/crowds-5-5.pm"); + std::shared_ptr> model = storm::builder::DdPrismModelBuilder::translateProgram(program); + + ASSERT_TRUE(model->getType() == storm::models::ModelType::Dtmc); + + std::pair, storm::dd::Dd> statesWithProbability01; + + ASSERT_NO_THROW(statesWithProbability01 = storm::utility::graph::performProb01(*model->as>(), model->getReachableStates(), model->getStates("observe0Greater1"))); + EXPECT_EQ(4409, statesWithProbability01.first.getNonZeroCount()); + EXPECT_EQ(1316, statesWithProbability01.second.getNonZeroCount()); + + ASSERT_NO_THROW(statesWithProbability01 = storm::utility::graph::performProb01(*model->as>(), model->getReachableStates(), model->getStates("observeIGreater1"))); + EXPECT_EQ(1091, statesWithProbability01.first.getNonZeroCount()); + EXPECT_EQ(4802, statesWithProbability01.second.getNonZeroCount()); + + ASSERT_NO_THROW(statesWithProbability01 = storm::utility::graph::performProb01(*model->as>(), model->getReachableStates(), model->getStates("observeOnlyTrueSender"))); + EXPECT_EQ(5829, statesWithProbability01.first.getNonZeroCount()); + EXPECT_EQ(1032, statesWithProbability01.second.getNonZeroCount()); +} + +TEST(GraphTest, ExplicitProb01) { + storm::prism::Program program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/builder/crowds-5-5.pm"); + std::shared_ptr> model = storm::builder::ExplicitPrismModelBuilder::translateProgram(program); + + ASSERT_TRUE(model->getType() == storm::models::ModelType::Dtmc); + + std::pair statesWithProbability01; + + ASSERT_NO_THROW(statesWithProbability01 = storm::utility::graph::performProb01(*model->as>(), storm::storage::BitVector(model->getNumberOfStates(), true), model->getStates("observe0Greater1"))); + EXPECT_EQ(4409, statesWithProbability01.first.getNumberOfSetBits()); + EXPECT_EQ(1316, statesWithProbability01.second.getNumberOfSetBits()); + + ASSERT_NO_THROW(statesWithProbability01 = storm::utility::graph::performProb01(*model->as>(), storm::storage::BitVector(model->getNumberOfStates(), true), model->getStates("observeIGreater1"))); + EXPECT_EQ(1091, statesWithProbability01.first.getNumberOfSetBits()); + EXPECT_EQ(4802, statesWithProbability01.second.getNumberOfSetBits()); + + ASSERT_NO_THROW(statesWithProbability01 = storm::utility::graph::performProb01(*model->as>(), storm::storage::BitVector(model->getNumberOfStates(), true), model->getStates("observeOnlyTrueSender"))); + EXPECT_EQ(5829, statesWithProbability01.first.getNumberOfSetBits()); + EXPECT_EQ(1032, statesWithProbability01.second.getNumberOfSetBits()); +} \ No newline at end of file diff --git a/test/performance/graph/GraphTest.cpp b/test/performance/graph/GraphTest.cpp index 10f65ea20..7ff4e4989 100644 --- a/test/performance/graph/GraphTest.cpp +++ b/test/performance/graph/GraphTest.cpp @@ -7,29 +7,26 @@ #include "src/models/sparse/Mdp.h" #include "src/models/sparse/Dtmc.h" -TEST(GraphTest, PerformProb01) { +TEST(GraphTest, ExplicitProb01) { std::shared_ptr> abstractModel = storm::parser::AutoParser::parseModel(STORM_CPP_BASE_PATH "/examples/dtmc/crowds/crowds20_5.tra", STORM_CPP_BASE_PATH "/examples/dtmc/crowds/crowds20_5.lab", "", ""); std::shared_ptr> dtmc = abstractModel->as>(); storm::storage::BitVector trueStates(dtmc->getNumberOfStates(), true); - - LOG4CPLUS_WARN(logger, "Computing prob01 (3 times) for crowds/crowds20_5..."); - std::pair prob01(storm::utility::graph::performProb01(*dtmc, trueStates, storm::storage::BitVector(dtmc->getStates("observe0Greater1")))); + std::pair prob01(storm::utility::graph::performProb01(*dtmc, trueStates, dtmc->getStates("observe0Greater1"))); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 1724414ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 46046ull); - prob01 = storm::utility::graph::performProb01(*dtmc, trueStates, storm::storage::BitVector(dtmc->getStates("observeIGreater1"))); + prob01 = storm::utility::graph::performProb01(*dtmc, trueStates, dtmc->getStates("observeIGreater1")); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 574016ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 825797ull); - prob01 = storm::utility::graph::performProb01(*dtmc, trueStates, storm::storage::BitVector(dtmc->getStates("observeOnlyTrueSender"))); + prob01 = storm::utility::graph::performProb01(*dtmc, trueStates, dtmc->getStates("observeOnlyTrueSender")); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 1785309ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 40992ull); - LOG4CPLUS_WARN(logger, "Done."); dtmc = nullptr; @@ -38,9 +35,7 @@ TEST(GraphTest, PerformProb01) { std::shared_ptr> dtmc2 = abstractModel->as>(); trueStates = storm::storage::BitVector(dtmc2->getNumberOfStates(), true); - LOG4CPLUS_WARN(logger, "Computing prob01 for synchronous_leader/leader6_8..."); prob01 = storm::utility::graph::performProb01(*dtmc2, trueStates, storm::storage::BitVector(dtmc2->getStates("elected"))); - LOG4CPLUS_WARN(logger, "Done."); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 0ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 1312334ull); @@ -53,16 +48,12 @@ TEST(GraphTest, PerformProb01MinMax) { std::shared_ptr> mdp = abstractModel->as>(); storm::storage::BitVector trueStates(mdp->getNumberOfStates(), true); - LOG4CPLUS_WARN(logger, "Computing prob01min for asynchronous_leader/leader7..."); std::pair prob01(storm::utility::graph::performProb01Min(*mdp, trueStates, mdp->getStates("elected"))); - LOG4CPLUS_WARN(logger, "Done."); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 0ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 2095783ull); - LOG4CPLUS_WARN(logger, "Computing prob01max for asynchronous_leader/leader7..."); prob01 = storm::utility::graph::performProb01Max(*mdp, trueStates, mdp->getStates("elected")); - LOG4CPLUS_WARN(logger, "Done."); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 0ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 2095783ull); @@ -73,16 +64,12 @@ TEST(GraphTest, PerformProb01MinMax) { std::shared_ptr> mdp2 = abstractModel->as>(); trueStates = storm::storage::BitVector(mdp2->getNumberOfStates(), true); - LOG4CPLUS_WARN(logger, "Computing prob01min for consensus/coin4_6..."); prob01 = storm::utility::graph::performProb01Min(*mdp2, trueStates, mdp2->getStates("finished")); - LOG4CPLUS_WARN(logger, "Done."); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 0ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 63616ull); - LOG4CPLUS_WARN(logger, "Computing prob01max for consensus/coin4_6..."); prob01 = storm::utility::graph::performProb01Max(*mdp2, trueStates, mdp2->getStates("finished")); - LOG4CPLUS_WARN(logger, "Done."); ASSERT_EQ(prob01.first.getNumberOfSetBits(), 0ull); ASSERT_EQ(prob01.second.getNumberOfSetBits(), 63616ull);