From b4d8c209cd26d801fcb1d231f575d4176ac6d9f5 Mon Sep 17 00:00:00 2001 From: dehnert Date: Wed, 21 Mar 2018 21:54:24 +0100 Subject: [PATCH] optimizations for game-based abstraction refinement --- src/storm/abstraction/ExpressionTranslator.cpp | 11 ++++++++++- src/storm/abstraction/MenuGameAbstractor.h | 7 +++++++ src/storm/abstraction/MenuGameRefiner.cpp | 8 +++++--- .../abstraction/jani/JaniMenuGameAbstractor.cpp | 13 ++++++++++++- .../abstraction/jani/JaniMenuGameAbstractor.h | 5 +++++ .../abstraction/prism/PrismMenuGameAbstractor.cpp | 15 +++++++++++++-- .../abstraction/prism/PrismMenuGameAbstractor.h | 5 +++++ .../abstraction/GameBasedMdpModelChecker.cpp | 4 ++++ .../storage/expressions/EquivalenceChecker.cpp | 10 ++++++++++ .../storage/expressions/EquivalenceChecker.h | 3 +++ 10 files changed, 74 insertions(+), 7 deletions(-) diff --git a/src/storm/abstraction/ExpressionTranslator.cpp b/src/storm/abstraction/ExpressionTranslator.cpp index fe2d28f97..76dd8a2a3 100644 --- a/src/storm/abstraction/ExpressionTranslator.cpp +++ b/src/storm/abstraction/ExpressionTranslator.cpp @@ -17,7 +17,8 @@ namespace storm { template ExpressionTranslator::ExpressionTranslator(AbstractionInformation& abstractionInformation, std::unique_ptr&& smtSolver) : abstractionInformation(abstractionInformation), equivalenceChecker(std::move(smtSolver)), locationVariables(abstractionInformation.getLocationExpressionVariables()), abstractedVariables(abstractionInformation.getAbstractedVariables()) { - // Intentionally left empty. + + equivalenceChecker.addConstraints(abstractionInformation.getConstraints()); } template @@ -51,6 +52,8 @@ namespace storm { for (uint64_t predicateIndex = 0; predicateIndex < abstractionInformation.get().getNumberOfPredicates(); ++predicateIndex) { if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), expression.toExpression())) { return abstractionInformation.get().encodePredicateAsSource(predicateIndex); + } else if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), !expression.toExpression())) { + return !abstractionInformation.get().encodePredicateAsSource(predicateIndex); } } @@ -108,6 +111,8 @@ namespace storm { for (uint64_t predicateIndex = 0; predicateIndex < abstractionInformation.get().getNumberOfPredicates(); ++predicateIndex) { if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), expression.toExpression())) { return abstractionInformation.get().encodePredicateAsSource(predicateIndex); + } else if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), !expression.toExpression())) { + return !abstractionInformation.get().encodePredicateAsSource(predicateIndex); } } @@ -124,6 +129,8 @@ namespace storm { for (uint64_t predicateIndex = 0; predicateIndex < abstractionInformation.get().getNumberOfPredicates(); ++predicateIndex) { if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), expression.toExpression())) { return abstractionInformation.get().encodePredicateAsSource(predicateIndex); + } else if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), !expression.toExpression())) { + return !abstractionInformation.get().encodePredicateAsSource(predicateIndex); } } @@ -154,6 +161,8 @@ namespace storm { for (uint64_t predicateIndex = 0; predicateIndex < abstractionInformation.get().getNumberOfPredicates(); ++predicateIndex) { if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), expression.toExpression())) { return abstractionInformation.get().encodePredicateAsSource(predicateIndex); + } else if (equivalenceChecker.areEquivalent(abstractionInformation.get().getPredicateByIndex(predicateIndex), !expression.toExpression())) { + return !abstractionInformation.get().encodePredicateAsSource(predicateIndex); } } diff --git a/src/storm/abstraction/MenuGameAbstractor.h b/src/storm/abstraction/MenuGameAbstractor.h index e82c1fb31..646292106 100644 --- a/src/storm/abstraction/MenuGameAbstractor.h +++ b/src/storm/abstraction/MenuGameAbstractor.h @@ -51,6 +51,13 @@ namespace storm { /// Retrieves the number of predicates currently in use. virtual uint64_t getNumberOfPredicates() const = 0; + /*! + * Adds the expression to the ones characterizing terminal states, i.e. states whose transitions are not + * explored. For this to work, appropriate predicates must have been used to refine the abstraction, + * otherwise this will fail. + */ + virtual void addTerminalStates(storm::expressions::Expression const& expression) = 0; + protected: void exportToDot(storm::abstraction::MenuGame const& currentGame, std::string const& filename, storm::dd::Bdd const& highlightStatesBdd, storm::dd::Bdd const& filter) const; }; diff --git a/src/storm/abstraction/MenuGameRefiner.cpp b/src/storm/abstraction/MenuGameRefiner.cpp index e139a357f..cd3522ac0 100644 --- a/src/storm/abstraction/MenuGameRefiner.cpp +++ b/src/storm/abstraction/MenuGameRefiner.cpp @@ -59,6 +59,8 @@ namespace storm { template MenuGameRefiner::MenuGameRefiner(MenuGameAbstractor& abstractor, std::unique_ptr&& smtSolver) : abstractor(abstractor), useInterpolation(storm::settings::getModule().isUseInterpolationSet()), splitAll(false), splitPredicates(false), addedAllGuardsFlag(false), pivotSelectionHeuristic(storm::settings::getModule().getPivotSelectionHeuristic()), splitter(), equivalenceChecker(std::move(smtSolver)) { + equivalenceChecker.addConstraints(abstractor.getAbstractionInformation().getConstraints()); + AbstractionSettings::SplitMode splitMode = storm::settings::getModule().getSplitMode(); splitAll = splitMode == AbstractionSettings::SplitMode::All; splitPredicates = splitMode == AbstractionSettings::SplitMode::NonGuard; @@ -325,7 +327,7 @@ namespace storm { } for (auto const& otherPredicate : otherRefinementPredicates) { for (uint64_t index = 0; index < possibleRefinementPredicates.size(); ++index) { - if (equivalenceChecker.areEquivalent(otherPredicate, possibleRefinementPredicates[index])) { + if (equivalenceChecker.areEquivalentModuloNegation(otherPredicate, possibleRefinementPredicates[index])) { ++refinementPredicateIndexToCount[index]; } } @@ -726,13 +728,13 @@ namespace storm { // set or in the set that is to be added. bool addAtom = true; for (auto const& oldPredicate : abstractionInformation.getPredicates()) { - if (equivalenceChecker.areEquivalent(atom, oldPredicate)) { + if (equivalenceChecker.areEquivalent(atom, oldPredicate) || equivalenceChecker.areEquivalent(atom, !oldPredicate)) { addAtom = false; break; } } for (auto const& addedAtom : cleanedAtoms) { - if (equivalenceChecker.areEquivalent(addedAtom, atom)) { + if (equivalenceChecker.areEquivalent(addedAtom, atom) || equivalenceChecker.areEquivalent(addedAtom, !atom)) { addAtom = false; break; } diff --git a/src/storm/abstraction/jani/JaniMenuGameAbstractor.cpp b/src/storm/abstraction/jani/JaniMenuGameAbstractor.cpp index a3900ab80..0e63bc0a5 100644 --- a/src/storm/abstraction/jani/JaniMenuGameAbstractor.cpp +++ b/src/storm/abstraction/jani/JaniMenuGameAbstractor.cpp @@ -154,8 +154,14 @@ namespace storm { auto auxVariables = abstractionInformation.getAuxVariableSet(0, abstractionInformation.getAuxVariableCount()); variablesToAbstract.insert(auxVariables.begin(), auxVariables.end()); + // Compute which states are non-terminal. + storm::dd::Bdd nonTerminalStates = this->abstractionInformation.getDdManager().getBddOne(); + for (auto const& expression : this->terminalStateExpressions) { + nonTerminalStates &= !this->getStates(expression); + } + // Do a reachability analysis on the raw transition relation. - storm::dd::Bdd transitionRelation = game.bdd.existsAbstract(variablesToAbstract); + storm::dd::Bdd transitionRelation = nonTerminalStates && game.bdd.existsAbstract(variablesToAbstract); storm::dd::Bdd initialStates = initialLocationsBdd && initialStateAbstractor.getAbstractStates(); initialStates.addMetaVariables(abstractionInformation.getSourcePredicateVariables()); storm::dd::Bdd reachableStates = storm::utility::dd::computeReachableStates(initialStates, transitionRelation, abstractionInformation.getSourceVariables(), abstractionInformation.getSuccessorVariables()); @@ -213,6 +219,11 @@ namespace storm { return abstractionInformation.getNumberOfPredicates(); } + template + void JaniMenuGameAbstractor::addTerminalStates(storm::expressions::Expression const& expression) { + terminalStateExpressions.emplace_back(expression); + } + // Explicitly instantiate the class. template class JaniMenuGameAbstractor; template class JaniMenuGameAbstractor; diff --git a/src/storm/abstraction/jani/JaniMenuGameAbstractor.h b/src/storm/abstraction/jani/JaniMenuGameAbstractor.h index 0a4cc36ed..0ef49f908 100644 --- a/src/storm/abstraction/jani/JaniMenuGameAbstractor.h +++ b/src/storm/abstraction/jani/JaniMenuGameAbstractor.h @@ -111,6 +111,8 @@ namespace storm { virtual uint64_t getNumberOfPredicates() const override; + virtual void addTerminalStates(storm::expressions::Expression const& expression) override; + protected: using MenuGameAbstractor::exportToDot; @@ -159,6 +161,9 @@ namespace storm { // A flag storing whether a refinement was performed. bool refinementPerformed; + + // A list of terminal state expressions. + std::vector terminalStateExpressions; }; } } diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp index a93e62e67..e5f2580b4 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp @@ -99,7 +99,7 @@ namespace storm { MenuGame PrismMenuGameAbstractor::abstract() { if (refinementPerformed) { currentGame = buildGame(); - refinementPerformed = true; + refinementPerformed = false; } return *currentGame; } @@ -147,8 +147,14 @@ namespace storm { auto auxVariables = abstractionInformation.getAuxVariableSet(0, abstractionInformation.getAuxVariableCount()); variablesToAbstract.insert(auxVariables.begin(), auxVariables.end()); + // Compute which states are non-terminal. + storm::dd::Bdd nonTerminalStates = this->abstractionInformation.getDdManager().getBddOne(); + for (auto const& expression : this->terminalStateExpressions) { + nonTerminalStates &= !this->getStates(expression); + } + // Do a reachability analysis on the raw transition relation. - storm::dd::Bdd transitionRelation = game.bdd.existsAbstract(variablesToAbstract); + storm::dd::Bdd transitionRelation = nonTerminalStates && game.bdd.existsAbstract(variablesToAbstract); storm::dd::Bdd initialStates = initialStateAbstractor.getAbstractStates(); initialStates.addMetaVariables(abstractionInformation.getSourcePredicateVariables()); storm::dd::Bdd reachableStates = storm::utility::dd::computeReachableStates(initialStates, transitionRelation, abstractionInformation.getSourceVariables(), abstractionInformation.getSuccessorVariables()); @@ -208,6 +214,11 @@ namespace storm { return abstractionInformation.getNumberOfPredicates(); } + template + void PrismMenuGameAbstractor::addTerminalStates(storm::expressions::Expression const& expression) { + terminalStateExpressions.emplace_back(expression); + } + // Explicitly instantiate the class. template class PrismMenuGameAbstractor; template class PrismMenuGameAbstractor; diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h index 931170df9..55845f134 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h @@ -111,6 +111,8 @@ namespace storm { virtual uint64_t getNumberOfPredicates() const override; + virtual void addTerminalStates(storm::expressions::Expression const& expression) override; + protected: using MenuGameAbstractor::exportToDot; @@ -156,6 +158,9 @@ namespace storm { // A flag storing whether a refinement was performed. bool refinementPerformed; + + // A list of terminal state expressions. + std::vector terminalStateExpressions; }; } } diff --git a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp index 6a42b6530..7a58e4f39 100644 --- a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp +++ b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp @@ -342,6 +342,10 @@ namespace storm { } else { abstractor = std::make_shared>(preprocessedModel.asJaniModel(), smtSolverFactory); } + if (!constraintExpression.isTrue()) { + abstractor->addTerminalStates(!constraintExpression); + } + abstractor->addTerminalStates(targetStateExpression); // Create a refiner that can be used to refine the abstraction when needed. storm::abstraction::MenuGameRefiner refiner(*abstractor, smtSolverFactory->create(preprocessedModel.getManager())); diff --git a/src/storm/storage/expressions/EquivalenceChecker.cpp b/src/storm/storage/expressions/EquivalenceChecker.cpp index b8d538656..c00d4fae6 100644 --- a/src/storm/storage/expressions/EquivalenceChecker.cpp +++ b/src/storm/storage/expressions/EquivalenceChecker.cpp @@ -13,6 +13,12 @@ namespace storm { } } + void EquivalenceChecker::addConstraints(std::vector const& constraints) { + for (auto const& constraint : constraints) { + this->smtSolver->add(constraint); + } + } + bool EquivalenceChecker::areEquivalent(storm::expressions::Expression const& first, storm::expressions::Expression const& second) { this->smtSolver->push(); this->smtSolver->add((first && !second) || (!first && second)); @@ -21,5 +27,9 @@ namespace storm { return equivalent; } + bool EquivalenceChecker::areEquivalentModuloNegation(storm::expressions::Expression const& first, storm::expressions::Expression const& second) { + return this->areEquivalent(first, second) || this->areEquivalent(first, !second); + } + } } diff --git a/src/storm/storage/expressions/EquivalenceChecker.h b/src/storm/storage/expressions/EquivalenceChecker.h index 2ec39ba8a..5d60eb50b 100644 --- a/src/storm/storage/expressions/EquivalenceChecker.h +++ b/src/storm/storage/expressions/EquivalenceChecker.h @@ -20,7 +20,10 @@ namespace storm { */ EquivalenceChecker(std::unique_ptr&& smtSolver, boost::optional const& constraint = boost::none); + void addConstraints(std::vector const& constraints); + bool areEquivalent(storm::expressions::Expression const& first, storm::expressions::Expression const& second); + bool areEquivalentModuloNegation(storm::expressions::Expression const& first, storm::expressions::Expression const& second); private: std::unique_ptr smtSolver;