From 24eede3e197c4a5b2d1fbb47b56045260481e01b Mon Sep 17 00:00:00 2001 From: dehnert Date: Sun, 27 Nov 2016 17:31:19 +0100 Subject: [PATCH] introduced refinement command to capture a specific refinement --- .../abstraction/AbstractionInformation.cpp | 14 +++-- .../abstraction/AbstractionInformation.h | 17 +++++-- src/storm/abstraction/MenuGameAbstractor.h | 5 +- src/storm/abstraction/MenuGameRefiner.cpp | 44 +++++++++++++--- src/storm/abstraction/MenuGameRefiner.h | 13 ++++- src/storm/abstraction/RefinementCommand.cpp | 27 ++++++++++ src/storm/abstraction/RefinementCommand.h | 36 +++++++++++++ .../prism/PrismMenuGameAbstractor.cpp | 51 +++++++++---------- .../prism/PrismMenuGameAbstractor.h | 34 +++++++++---- 9 files changed, 186 insertions(+), 55 deletions(-) create mode 100644 src/storm/abstraction/RefinementCommand.cpp create mode 100644 src/storm/abstraction/RefinementCommand.h diff --git a/src/storm/abstraction/AbstractionInformation.cpp b/src/storm/abstraction/AbstractionInformation.cpp index b7e61712e..8889e0afb 100644 --- a/src/storm/abstraction/AbstractionInformation.cpp +++ b/src/storm/abstraction/AbstractionInformation.cpp @@ -14,7 +14,7 @@ namespace storm { namespace abstraction { template - AbstractionInformation::AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::shared_ptr> ddManager) : expressionManager(expressionManager), ddManager(ddManager), allPredicateIdentities(ddManager->getBddOne()) { + AbstractionInformation::AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::unique_ptr&& smtSolver, std::shared_ptr> ddManager) : expressionManager(expressionManager), equivalenceChecker(std::move(smtSolver)), ddManager(ddManager), allPredicateIdentities(ddManager->getBddOne()) { // Intentionally left empty. } @@ -40,7 +40,15 @@ namespace storm { } template - uint_fast64_t AbstractionInformation::addPredicate(storm::expressions::Expression const& predicate) { + uint_fast64_t AbstractionInformation::getOrAddPredicate(storm::expressions::Expression const& predicate) { + // Check if we already have an equivalent predicate. + for (uint64_t index = 0; index < predicates.size(); ++index) { + auto const& oldPredicate = predicates[index]; + if (equivalenceChecker.areEquivalent(oldPredicate, predicate)) { + return index; + } + } + std::size_t predicateIndex = predicates.size(); predicateToIndexMap[predicate] = predicateIndex; @@ -68,7 +76,7 @@ namespace storm { std::vector AbstractionInformation::addPredicates(std::vector const& predicates) { std::vector predicateIndices; for (auto const& predicate : predicates) { - predicateIndices.push_back(this->addPredicate(predicate)); + predicateIndices.push_back(this->getOrAddPredicate(predicate)); } return predicateIndices; } diff --git a/src/storm/abstraction/AbstractionInformation.h b/src/storm/abstraction/AbstractionInformation.h index 0a4f0ef03..bfdefbcfb 100644 --- a/src/storm/abstraction/AbstractionInformation.h +++ b/src/storm/abstraction/AbstractionInformation.h @@ -3,11 +3,16 @@ #include #include #include +#include #include "storm/storage/dd/DdType.h" #include "storm/storage/dd/Bdd.h" +#include "storm/solver/SmtSolver.h" + +#include "storm/storage/expressions/EquivalenceChecker.h" + namespace storm { namespace expressions { class ExpressionManager; @@ -29,8 +34,10 @@ namespace storm { * Creates a new abstraction information object. * * @param expressionManager The manager responsible for all variables and expressions during the abstraction process. + * @param smtSolver An SMT solver that is used to detect equivalent predicates. + * @param ddManager The manager responsible for the DDs. */ - AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::shared_ptr> ddManager = std::make_shared>()); + AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::unique_ptr&& smtSolver, std::shared_ptr> ddManager = std::make_shared>()); /*! * Adds the given variable. @@ -69,12 +76,13 @@ namespace storm { std::vector const& getConstraints() const; /*! - * Adds the given predicate. + * Gets the index of a predicate that is equivalent to the provided one. If none exists, the predicate is + * added. * * @param predicate The predicate to add. * @return The index of the newly added predicate in the global list of predicates. */ - uint_fast64_t addPredicate(storm::expressions::Expression const& predicate); + uint_fast64_t getOrAddPredicate(storm::expressions::Expression const& predicate); /*! * Adds the given predicates. @@ -462,6 +470,9 @@ namespace storm { /// A mapping from predicates to their indices in the predicate list. std::unordered_map predicateToIndexMap; + /// An object that can detect equivalence of predicates. + storm::expressions::EquivalenceChecker equivalenceChecker; + /// The current set of predicates used in the abstraction. std::vector predicates; diff --git a/src/storm/abstraction/MenuGameAbstractor.h b/src/storm/abstraction/MenuGameAbstractor.h index 9cc820d30..281264a48 100644 --- a/src/storm/abstraction/MenuGameAbstractor.h +++ b/src/storm/abstraction/MenuGameAbstractor.h @@ -1,10 +1,12 @@ #pragma once #include +#include #include "storm/storage/dd/DdType.h" #include "storm/abstraction/MenuGame.h" +#include "storm/abstraction/RefinementCommand.h" #include "storm/storage/expressions/Expression.h" @@ -23,10 +25,11 @@ namespace storm { /// Retrieves information about the abstraction. virtual AbstractionInformation const& getAbstractionInformation() const = 0; virtual storm::expressions::Expression const& getGuard(uint64_t player1Choice) const = 0; + virtual std::pair getPlayer1ChoiceRange() const = 0; virtual std::map getVariableUpdates(uint64_t player1Choice, uint64_t auxiliaryChoice) const = 0; /// Methods to refine the abstraction. - virtual void refine(std::vector const& predicates) = 0; + virtual void refine(std::vector const& commands) = 0; /// Exports a representation of the current abstraction state in the dot format. virtual void exportToDot(std::string const& filename, storm::dd::Bdd const& highlightStates, storm::dd::Bdd const& filter) const = 0; diff --git a/src/storm/abstraction/MenuGameRefiner.cpp b/src/storm/abstraction/MenuGameRefiner.cpp index ca26fd8dd..ce8fc573d 100644 --- a/src/storm/abstraction/MenuGameRefiner.cpp +++ b/src/storm/abstraction/MenuGameRefiner.cpp @@ -13,12 +13,21 @@ namespace storm { template MenuGameRefiner::MenuGameRefiner(MenuGameAbstractor& abstractor, std::unique_ptr&& smtSolver) : abstractor(abstractor), splitPredicates(storm::settings::getModule().isSplitPredicatesSet()), splitter(), equivalenceChecker(std::move(smtSolver)) { - // Intentionally left empty. + + if (storm::settings::getModule().isAddAllGuardsSet()) { + std::vector guards; + + std::pair player1Choices = this->abstractor.get().getPlayer1ChoiceRange(); + for (uint64_t index = player1Choices.first; index < player1Choices.second; ++index) { + guards.push_back(this->abstractor.get().getGuard(index)); + } + performRefinement(createGlobalRefinement(guards)); + } } template void MenuGameRefiner::refine(std::vector const& predicates) const { - abstractor.get().refine(predicates); + performRefinement(createGlobalRefinement(predicates)); } template @@ -226,7 +235,8 @@ namespace storm { // Derive predicate based on the selected pivot state. storm::expressions::Expression newPredicate = derivePredicateFromPivotState(game, pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); - performRefinement({newPredicate}); + std::vector preparedPredicates = preprocessPredicates({newPredicate}); + performRefinement(createGlobalRefinement(preparedPredicates)); return true; } @@ -253,12 +263,13 @@ namespace storm { // Derive predicate based on the selected pivot state. storm::expressions::Expression newPredicate = derivePredicateFromPivotState(game, pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); - performRefinement({newPredicate}); + std::vector preparedPredicates = preprocessPredicates({newPredicate}); + performRefinement(createGlobalRefinement(preparedPredicates)); return true; } template - bool MenuGameRefiner::performRefinement(std::vector const& predicates) const { + std::vector MenuGameRefiner::preprocessPredicates(std::vector const& predicates) const { if (splitPredicates) { std::vector cleanedAtoms; @@ -292,13 +303,30 @@ namespace storm { } } - abstractor.get().refine(cleanedAtoms); + return cleanedAtoms; } else { // If no splitting of the predicates is required, just forward the refinement request to the abstractor. - abstractor.get().refine(predicates); } - return true; + return predicates; + } + + template + std::vector MenuGameRefiner::createGlobalRefinement(std::vector const& predicates) const { + std::vector commands; + + // std::pair player1Choices = abstractor.get().getPlayer1ChoiceRange(); + // for (uint64_t index = player1Choices.first; index < player1Choices.second; ++index) { + // commands.emplace_back(index, predicates); + // } + commands.emplace_back(predicates); + + return commands; + } + + template + void MenuGameRefiner::performRefinement(std::vector const& refinementCommands) const { + abstractor.get().refine(refinementCommands); } template class MenuGameRefiner; diff --git a/src/storm/abstraction/MenuGameRefiner.h b/src/storm/abstraction/MenuGameRefiner.h index 7394c14b3..a284966b9 100644 --- a/src/storm/abstraction/MenuGameRefiner.h +++ b/src/storm/abstraction/MenuGameRefiner.h @@ -4,6 +4,7 @@ #include #include +#include "storm/abstraction/RefinementCommand.h" #include "storm/abstraction/QualitativeResultMinMax.h" #include "storm/abstraction/QuantitativeResultMinMax.h" @@ -54,10 +55,18 @@ namespace storm { private: storm::expressions::Expression derivePredicateFromDifferingChoices(storm::dd::Bdd const& pivotState, storm::dd::Bdd const& player1Choice, storm::dd::Bdd const& lowerChoice, storm::dd::Bdd const& upperChoice) const; storm::expressions::Expression derivePredicateFromPivotState(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& pivotState, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const; + /*! - * Takes the given predicates, preprocesses them and then refines the abstractor. + * Preprocesses the predicates. */ - bool performRefinement(std::vector const& predicates) const; + std::vector preprocessPredicates(std::vector const& predicates) const; + + /*! + * Creates a set of refinement commands that amounts to splitting all player 1 choices with the given set of predicates. + */ + std::vector createGlobalRefinement(std::vector const& predicates) const; + + void performRefinement(std::vector const& refinementCommands) const; /// The underlying abstractor to refine. std::reference_wrapper> abstractor; diff --git a/src/storm/abstraction/RefinementCommand.cpp b/src/storm/abstraction/RefinementCommand.cpp new file mode 100644 index 000000000..db83cbb18 --- /dev/null +++ b/src/storm/abstraction/RefinementCommand.cpp @@ -0,0 +1,27 @@ +#include "storm/abstraction/RefinementCommand.h" + +namespace storm { + namespace abstraction { + + RefinementCommand::RefinementCommand(uint64_t referencedPlayer1Choice, std::vector const& predicates) : referencedPlayer1Choice(referencedPlayer1Choice), predicates(predicates) { + // Intentionally left empty. + } + + RefinementCommand::RefinementCommand(std::vector const& predicates) : predicates(predicates) { + // Intentionally left empty. + } + + bool RefinementCommand::refersToPlayer1Choice() const { + return static_cast(referencedPlayer1Choice); + } + + uint64_t RefinementCommand::getReferencedPlayer1Choice() const { + return referencedPlayer1Choice.get(); + } + + std::vector const& RefinementCommand::getPredicates() const { + return predicates; + } + + } +} diff --git a/src/storm/abstraction/RefinementCommand.h b/src/storm/abstraction/RefinementCommand.h new file mode 100644 index 000000000..dae7b0f2c --- /dev/null +++ b/src/storm/abstraction/RefinementCommand.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +#include + +#include "storm/storage/expressions/Expression.h" + +namespace storm { + namespace abstraction { + + class RefinementCommand { + public: + /*! + * Creates a new refinement command for the given player 1 choice. + */ + RefinementCommand(uint64_t referencedPlayer1Choice, std::vector const& predicates); + + /*! + * Creates a new refinement command for all player 1 choices. + */ + RefinementCommand(std::vector const& predicates); + + /// Access to the details of this refinement commands. + bool refersToPlayer1Choice() const; + uint64_t getReferencedPlayer1Choice() const; + std::vector const& getPredicates() const; + + private: + boost::optional referencedPlayer1Choice; + std::vector predicates; + }; + + } +} diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp index c52f972cd..f9b8479a0 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp @@ -17,6 +17,7 @@ #include "storm/utility/solver.h" #include "storm/exceptions/WrongFormatException.h" #include "storm/exceptions/InvalidArgumentException.h" +#include "storm/exceptions/NotSupportedException.h" #include "storm-config.h" #include "storm/adapters/CarlAdapter.h" @@ -29,9 +30,8 @@ namespace storm { template PrismMenuGameAbstractor::PrismMenuGameAbstractor(storm::prism::Program const& program, - std::shared_ptr const& smtSolverFactory, - bool addAllGuards) - : program(program), smtSolverFactory(smtSolverFactory), abstractionInformation(program.getManager()), modules(), initialStateAbstractor(abstractionInformation, program.getAllExpressionVariables(), {program.getInitialStatesExpression()}, this->smtSolverFactory), addedAllGuards(addAllGuards), currentGame(nullptr), refinementPerformed(false) { + std::shared_ptr const& smtSolverFactory) + : program(program), smtSolverFactory(smtSolverFactory), abstractionInformation(program.getManager(), smtSolverFactory->create(program.getManager())), modules(), initialStateAbstractor(abstractionInformation, program.getAllExpressionVariables(), {program.getInitialStatesExpression()}, this->smtSolverFactory), currentGame(nullptr), refinementPerformed(false) { // For now, we assume that there is a single module. If the program has more than one module, it needs // to be flattened before the procedure. @@ -52,9 +52,6 @@ namespace storm { for (auto const& module : program.getModules()) { // If we were requested to add all guards to the set of predicates, we do so now. for (auto const& command : module.getCommands()) { - if (addAllGuards) { - allGuards.push_back(command.getGuardExpression()); - } maximalUpdateCount = std::max(maximalUpdateCount, static_cast(command.getNumberOfUpdates())); } @@ -68,45 +65,40 @@ namespace storm { // For each module of the concrete program, we create an abstract counterpart. for (auto const& module : program.getModules()) { - this->modules.emplace_back(module, abstractionInformation, this->smtSolverFactory, addAllGuards); + this->modules.emplace_back(module, abstractionInformation, this->smtSolverFactory); } // Retrieve the command-update probability ADD, so we can multiply it with the abstraction BDD later. commandUpdateProbabilitiesAdd = modules.front().getCommandUpdateProbabilitiesAdd(); - - // Now that we have created all other DD variables, we create the DD variables for the predicates. - std::vector initialPredicates; - if (addAllGuards) { - for (auto const& guard : allGuards) { - initialPredicates.push_back(guard); - } + } + + template + void PrismMenuGameAbstractor::refine(std::vector const& commands) { + for (auto const& command : commands) { + STORM_LOG_THROW(!command.refersToPlayer1Choice(), storm::exceptions::NotSupportedException, "Currently only global refinement is supported."); + refine(command); + refinementPerformed |= !command.getPredicates().empty(); } - - // Finally, refine using the all predicates and build game as a by-product. - this->refine(initialPredicates); } template - void PrismMenuGameAbstractor::refine(std::vector const& predicates) { + void PrismMenuGameAbstractor::refine(RefinementCommand const& command) { // Add the predicates to the global list of predicates and gather their indices. - std::vector newPredicateIndices; - for (auto const& predicate : predicates) { + std::vector predicateIndices; + for (auto const& predicate : command.getPredicates()) { STORM_LOG_THROW(predicate.hasBooleanType(), storm::exceptions::InvalidArgumentException, "Expecting a predicate of type bool."); - newPredicateIndices.push_back(abstractionInformation.addPredicate(predicate)); + predicateIndices.push_back(abstractionInformation.getOrAddPredicate(predicate)); } // Refine all abstract modules. for (auto& module : modules) { - module.refine(newPredicateIndices); + module.refine(predicateIndices); } // Refine initial state abstractor. - initialStateAbstractor.refine(newPredicateIndices); - - // Update the flag that stores whether a refinement was performed. - refinementPerformed = refinementPerformed || !newPredicateIndices.empty(); + initialStateAbstractor.refine(predicateIndices); } - + template MenuGame PrismMenuGameAbstractor::abstract() { if (refinementPerformed) { @@ -131,6 +123,11 @@ namespace storm { return modules.front().getVariableUpdates(player1Choice, auxiliaryChoice); } + template + std::pair PrismMenuGameAbstractor::getPlayer1ChoiceRange() const { + return std::make_pair(0, modules.front().getCommands().size()); + } + template storm::dd::Bdd PrismMenuGameAbstractor::getStates(storm::expressions::Expression const& predicate) { STORM_LOG_ASSERT(currentGame != nullptr, "Game was not properly created."); diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h index 9d64d3eac..bef998346 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h @@ -5,6 +5,7 @@ #include "storm/abstraction/MenuGameAbstractor.h" #include "storm/abstraction/AbstractionInformation.h" #include "storm/abstraction/MenuGame.h" +#include "storm/abstraction/RefinementCommand.h" #include "storm/abstraction/prism/ModuleAbstractor.h" #include "storm/storage/dd/Add.h" @@ -42,9 +43,8 @@ namespace storm { * @param expressionManager The manager responsible for the expressions of the program. * @param program The concrete program for which to build the abstraction. * @param smtSolverFactory A factory that is to be used for creating new SMT solvers. - * @param addAllGuards A flag that indicates whether all guards of the program should be added to the initial set of predicates. */ - PrismMenuGameAbstractor(storm::prism::Program const& program, std::shared_ptr const& smtSolverFactory = std::make_shared(), bool addAllGuards = false); + PrismMenuGameAbstractor(storm::prism::Program const& program, std::shared_ptr const& smtSolverFactory = std::make_shared()); PrismMenuGameAbstractor(PrismMenuGameAbstractor const&) = default; PrismMenuGameAbstractor& operator=(PrismMenuGameAbstractor const&) = default; @@ -56,14 +56,14 @@ namespace storm { * * @return The abstract stochastic two player game. */ - MenuGame abstract(); + MenuGame abstract() override; /*! * Retrieves information about the abstraction. * * @return The abstraction information object. */ - AbstractionInformation const& getAbstractionInformation() const; + AbstractionInformation const& getAbstractionInformation() const override; /*! * Retrieves the guard predicate of the given player 1 choice. @@ -71,13 +71,18 @@ namespace storm { * @param player1Choice The choice for which to retrieve the guard. * @return The guard of the player 1 choice. */ - storm::expressions::Expression const& getGuard(uint64_t player1Choice) const; + storm::expressions::Expression const& getGuard(uint64_t player1Choice) const override; /*! * Retrieves a mapping from variables to expressions that define their updates wrt. to the given player * 1 choice and auxiliary choice. */ - std::map getVariableUpdates(uint64_t player1Choice, uint64_t auxiliaryChoice) const; + std::map getVariableUpdates(uint64_t player1Choice, uint64_t auxiliaryChoice) const override; + + /*! + * Retrieves the range of player 1 choices. + */ + std::pair getPlayer1ChoiceRange() const override; /*! * Retrieves the set of states (represented by a BDD) satisfying the given predicate, assuming that it @@ -89,12 +94,19 @@ namespace storm { storm::dd::Bdd getStates(storm::expressions::Expression const& predicate); /*! - * Refines the abstract program with the given predicates. + * Performs the given refinement commands. * - * @param predicates The new predicates. + * @param commands The commands to perform. */ - void refine(std::vector const& predicates); - + virtual void refine(std::vector const& commands) override; + + /*! + * Performs the given refinement command. + * + * @param command The command to perform. + */ + void refine(RefinementCommand const& command); + /*! * Exports the current state of the abstraction in the dot format to the given file. * @@ -102,7 +114,7 @@ namespace storm { * @param highlightStates A BDD characterizing states that will be highlighted. * @param filter A filter that is applied to select which part of the game to export. */ - void exportToDot(std::string const& filename, storm::dd::Bdd const& highlightStates, storm::dd::Bdd const& filter) const; + void exportToDot(std::string const& filename, storm::dd::Bdd const& highlightStates, storm::dd::Bdd const& filter) const override; private: /*!