diff --git a/src/storm/abstraction/AbstractionInformation.cpp b/src/storm/abstraction/AbstractionInformation.cpp index 9c2250dc5..4bee18228 100644 --- a/src/storm/abstraction/AbstractionInformation.cpp +++ b/src/storm/abstraction/AbstractionInformation.cpp @@ -14,7 +14,7 @@ namespace storm { namespace abstraction { template - AbstractionInformation::AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::unique_ptr&& smtSolver, std::shared_ptr> ddManager) : expressionManager(expressionManager), equivalenceChecker(std::move(smtSolver)), ddManager(ddManager), allPredicateIdentities(ddManager->getBddOne()) { + AbstractionInformation::AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::set const& allVariables, std::unique_ptr&& smtSolver, std::shared_ptr> ddManager) : expressionManager(expressionManager), equivalenceChecker(std::move(smtSolver)), variables(allVariables), ddManager(ddManager), allPredicateIdentities(ddManager->getBddOne()), expressionToBddMap() { // Intentionally left empty. } @@ -45,6 +45,7 @@ namespace storm { for (uint64_t index = 0; index < predicates.size(); ++index) { auto const& oldPredicate = predicates[index]; if (equivalenceChecker.areEquivalent(oldPredicate, predicate)) { + expressionToBddMap[predicate] = expressionToBddMap.at(oldPredicate); return index; } } @@ -70,6 +71,8 @@ namespace storm { orderedSourceVariables.push_back(newMetaVariable.first); orderedSuccessorVariables.push_back(newMetaVariable.second); ddVariableIndexToPredicateIndexMap[predicateIdentities.back().getIndex()] = predicateIndex; + expressionToBddMap[predicate] = predicateBdds[predicateIndex].first && !bottomStateBdds.first; + return predicateIndex; } @@ -311,14 +314,8 @@ namespace storm { } template - std::map> AbstractionInformation::getPredicateToBddMap() const { - std::map> result; - - for (uint_fast64_t index = 0; index < predicates.size(); ++index) { - result[predicates[index]] = predicateBdds[index].first && !bottomStateBdds.first; - } - - return result; + std::map> const& AbstractionInformation::getPredicateToBddMap() const { + return expressionToBddMap; } template @@ -446,8 +443,8 @@ namespace storm { } template - std::map AbstractionInformation::decodeChoiceToUpdateSuccessorMapping(storm::dd::Bdd const& choice) const { - std::map result; + std::map> AbstractionInformation::decodeChoiceToUpdateSuccessorMapping(storm::dd::Bdd const& choice) const { + std::map> result; storm::dd::Add lowerChoiceAsAdd = choice.template toAdd(); for (auto const& successorValuePair : lowerChoiceAsAdd) { @@ -461,7 +458,7 @@ namespace storm { } } - result[updateIndex] = successor; + result[updateIndex] = std::make_pair(successor, successorValuePair.second); } return result; } diff --git a/src/storm/abstraction/AbstractionInformation.h b/src/storm/abstraction/AbstractionInformation.h index 4ba2e4a70..9048f1614 100644 --- a/src/storm/abstraction/AbstractionInformation.h +++ b/src/storm/abstraction/AbstractionInformation.h @@ -34,10 +34,11 @@ namespace storm { * Creates a new abstraction information object. * * @param expressionManager The manager responsible for all variables and expressions during the abstraction process. + * @param allVariables All expression variables that can appear in predicates known to this object. * @param smtSolver An SMT solver that is used to detect equivalent predicates. * @param ddManager The manager responsible for the DDs. */ - AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::unique_ptr&& smtSolver, std::shared_ptr> ddManager = std::make_shared>()); + AbstractionInformation(storm::expressions::ExpressionManager& expressionManager, std::set const& allVariables, std::unique_ptr&& smtSolver, std::shared_ptr> ddManager = std::make_shared>()); /*! * Adds the given variable. @@ -372,7 +373,7 @@ namespace storm { * * @return A mapping from predicates to their representing BDDs. */ - std::map> getPredicateToBddMap() const; + std::map> const& getPredicateToBddMap() const; /*! * Retrieves the meta variables pairs for all predicates. @@ -455,7 +456,7 @@ namespace storm { /*! * Decodes the choice in the form of a BDD over the destination variables. */ - std::map decodeChoiceToUpdateSuccessorMapping(storm::dd::Bdd const& choice) const; + std::map> decodeChoiceToUpdateSuccessorMapping(storm::dd::Bdd const& choice) const; /*! * Decodes the given BDD (over source, player 1 and aux variables) into a bit vector indicating the truth @@ -564,6 +565,9 @@ namespace storm { /// The BDDs associated with the meta variables encoding auxiliary information. std::vector> auxVariableBdds; + + /// A mapping from expressions to the corresponding BDDs. + std::map> expressionToBddMap; }; } diff --git a/src/storm/abstraction/LocalExpressionInformation.cpp b/src/storm/abstraction/LocalExpressionInformation.cpp index 6424c55c1..98acb1558 100644 --- a/src/storm/abstraction/LocalExpressionInformation.cpp +++ b/src/storm/abstraction/LocalExpressionInformation.cpp @@ -1,5 +1,7 @@ #include "storm/abstraction/LocalExpressionInformation.h" +#include "storm/abstraction/AbstractionInformation.h" + #include #include "storm/utility/macros.h" @@ -7,7 +9,8 @@ namespace storm { namespace abstraction { - LocalExpressionInformation::LocalExpressionInformation(std::set const& relevantVariables, std::vector> const& expressionIndexPairs) : relevantVariables(relevantVariables), expressionBlocks(relevantVariables.size()) { + template + LocalExpressionInformation::LocalExpressionInformation(AbstractionInformation const& abstractionInformation) : relevantVariables(abstractionInformation.getVariables()), expressionBlocks(relevantVariables.size()), abstractionInformation(abstractionInformation) { // Assign each variable to a new block. uint_fast64_t currentBlock = 0; variableBlocks.resize(relevantVariables.size()); @@ -17,40 +20,39 @@ namespace storm { variableBlocks[currentBlock].insert(variable); ++currentBlock; } - - // Add all expressions, which might relate some variables. - for (auto const& expressionIndexPair : expressionIndexPairs) { - this->addExpression(expressionIndexPair.first, expressionIndexPair.second); - } } - bool LocalExpressionInformation::addExpression(storm::expressions::Expression const& expression, uint_fast64_t globalExpressionIndex) { + template + bool LocalExpressionInformation::addExpression(uint_fast64_t globalExpressionIndex) { + storm::expressions::Expression const& expression = abstractionInformation.get().getPredicateByIndex(globalExpressionIndex); + // Register the expression for all variables that appear in it. std::set expressionVariables = expression.getVariables(); for (auto const& variable : expressionVariables) { - variableToExpressionsMapping[variable].insert(this->expressions.size()); + variableToExpressionsMapping[variable].insert(globalExpressionIndex); } // Add the expression to the block of the first variable. When relating the variables, the blocks will // get merged (if necessary). STORM_LOG_ASSERT(!expressionVariables.empty(), "Found no variables in expression."); - expressionBlocks[getBlockIndexOfVariable(*expressionVariables.begin())].insert(this->expressions.size()); + expressionBlocks[getBlockIndexOfVariable(*expressionVariables.begin())].insert(globalExpressionIndex); // Add expression and relate all the appearing variables. - this->globalToLocalIndexMapping[globalExpressionIndex] = this->expressions.size(); - this->expressions.push_back(expression); return this->relate(expressionVariables); } - bool LocalExpressionInformation::areRelated(storm::expressions::Variable const& firstVariable, storm::expressions::Variable const& secondVariable) { + template + bool LocalExpressionInformation::areRelated(storm::expressions::Variable const& firstVariable, storm::expressions::Variable const& secondVariable) { return getBlockIndexOfVariable(firstVariable) == getBlockIndexOfVariable(secondVariable); } - bool LocalExpressionInformation::relate(storm::expressions::Variable const& firstVariable, storm::expressions::Variable const& secondVariable) { + template + bool LocalExpressionInformation::relate(storm::expressions::Variable const& firstVariable, storm::expressions::Variable const& secondVariable) { return this->relate({firstVariable, secondVariable}); } - bool LocalExpressionInformation::relate(std::set const& variables) { + template + bool LocalExpressionInformation::relate(std::set const& variables) { // Determine all blocks that need to be merged. std::set blocksToMerge; for (auto const& variable : variables) { @@ -68,7 +70,8 @@ namespace storm { return true; } - void LocalExpressionInformation::mergeBlocks(std::set const& blocksToMerge) { + template + void LocalExpressionInformation::mergeBlocks(std::set const& blocksToMerge) { // Merge all blocks into the block to keep. std::vector> newVariableBlocks; std::vector> newExpressionBlocks; @@ -108,28 +111,34 @@ namespace storm { expressionBlocks = std::move(newExpressionBlocks); } - std::set const& LocalExpressionInformation::getBlockOfVariable(storm::expressions::Variable const& variable) const { + template + std::set const& LocalExpressionInformation::getBlockOfVariable(storm::expressions::Variable const& variable) const { return variableBlocks[getBlockIndexOfVariable(variable)]; } - uint_fast64_t LocalExpressionInformation::getNumberOfBlocks() const { + template + uint_fast64_t LocalExpressionInformation::getNumberOfBlocks() const { return this->variableBlocks.size(); } - std::set const& LocalExpressionInformation::getVariableBlockWithIndex(uint_fast64_t blockIndex) const { + template + std::set const& LocalExpressionInformation::getVariableBlockWithIndex(uint_fast64_t blockIndex) const { return this->variableBlocks[blockIndex]; } - uint_fast64_t LocalExpressionInformation::getBlockIndexOfVariable(storm::expressions::Variable const& variable) const { + template + uint_fast64_t LocalExpressionInformation::getBlockIndexOfVariable(storm::expressions::Variable const& variable) const { STORM_LOG_ASSERT(this->relevantVariables.find(variable) != this->relevantVariables.end(), "Illegal variable '" << variable.getName() << "' for partition."); return this->variableToBlockMapping.find(variable)->second; } - std::set const& LocalExpressionInformation::getRelatedExpressions(storm::expressions::Variable const& variable) const { + template + std::set const& LocalExpressionInformation::getRelatedExpressions(storm::expressions::Variable const& variable) const { return this->expressionBlocks[getBlockIndexOfVariable(variable)]; } - std::set LocalExpressionInformation::getRelatedExpressions(std::set const& variables) const { + template + std::set LocalExpressionInformation::getRelatedExpressions(std::set const& variables) const { // Start by determining the indices of all expression blocks that are related to any of the variables. std::set relatedExpressionBlockIndices; for (auto const& variable : variables) { @@ -144,12 +153,14 @@ namespace storm { return result; } - std::set const& LocalExpressionInformation::getExpressionsUsingVariable(storm::expressions::Variable const& variable) const { + template + std::set const& LocalExpressionInformation::getExpressionsUsingVariable(storm::expressions::Variable const& variable) const { STORM_LOG_ASSERT(this->relevantVariables.find(variable) != this->relevantVariables.end(), "Illegal variable '" << variable.getName() << "' for partition."); return this->variableToExpressionsMapping.find(variable)->second; } - std::set LocalExpressionInformation::getExpressionsUsingVariables(std::set const& variables) const { + template + std::set LocalExpressionInformation::getExpressionsUsingVariables(std::set const& variables) const { std::set result; for (auto const& variable : variables) { @@ -161,11 +172,8 @@ namespace storm { return result; } - storm::expressions::Expression const& LocalExpressionInformation::getExpression(uint_fast64_t expressionIndex) const { - return this->expressions[expressionIndex]; - } - - std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition) { + template + std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition) { std::vector blocks; for (uint_fast64_t index = 0; index < partition.variableBlocks.size(); ++index) { auto const& variableBlock = partition.variableBlocks[index]; @@ -177,9 +185,9 @@ namespace storm { } std::vector expressionsInBlock; - for (auto const& expression : expressionBlock) { + for (auto const& expressionIndex : expressionBlock) { std::stringstream stream; - stream << partition.expressions[expression]; + stream << partition.abstractionInformation.get().getPredicateByIndex(expressionIndex); expressionsInBlock.push_back(stream.str()); } @@ -191,6 +199,11 @@ namespace storm { out << "}"; return out; } - + + template class LocalExpressionInformation; + template class LocalExpressionInformation; + + template std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition); + template std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition); } } diff --git a/src/storm/abstraction/LocalExpressionInformation.h b/src/storm/abstraction/LocalExpressionInformation.h index 391ed32aa..37c0a3b68 100644 --- a/src/storm/abstraction/LocalExpressionInformation.h +++ b/src/storm/abstraction/LocalExpressionInformation.h @@ -8,27 +8,31 @@ #include "storm/storage/expressions/Variable.h" #include "storm/storage/expressions/Expression.h" +#include "storm/storage/dd/DdType.h" + namespace storm { namespace abstraction { + template + class AbstractionInformation; + + template class LocalExpressionInformation { public: /*! * Constructs a new variable partition. * - * @param relevantVariables The variables of this partition. - * @param expressionIndexPairs The (initial) pairs of expressions and their global indices. + * @param abstractionInformation The object storing global information about the abstraction. */ - LocalExpressionInformation(std::set const& relevantVariables, std::vector> const& expressionIndexPairs = {}); + LocalExpressionInformation(AbstractionInformation const& abstractionInformation); /*! * Adds the expression and therefore indirectly may cause blocks of variables to be merged. * - * @param expression The expression to add. * @param globalExpressionIndex The global index of the expression. * @return True iff the partition changed. */ - bool addExpression(storm::expressions::Expression const& expression, uint_fast64_t globalExpressionIndex); + bool addExpression(uint_fast64_t globalExpressionIndex); /*! * Retrieves whether the two given variables are in the same block of the partition. @@ -119,15 +123,8 @@ namespace storm { */ std::set getExpressionsUsingVariables(std::set const& variables) const; - /*! - * Retrieves the expression with the given index. - * - * @param expressionIndex The index of the expression to retrieve. - * @return The corresponding expression. - */ - storm::expressions::Expression const& getExpression(uint_fast64_t expressionIndex) const; - - friend std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition); + template + friend std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition); private: /*! @@ -152,14 +149,12 @@ namespace storm { // A mapping from variables to the indices of all expressions they appear in. std::unordered_map> variableToExpressionsMapping; - // A mapping from global expression indices to local ones. - std::unordered_map globalToLocalIndexMapping; - - // The vector of all expressions. - std::vector expressions; + // The object storing the abstraction information. + std::reference_wrapper const> abstractionInformation; }; - std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition); + template + std::ostream& operator<<(std::ostream& out, LocalExpressionInformation const& partition); } } diff --git a/src/storm/abstraction/MenuGameRefiner.cpp b/src/storm/abstraction/MenuGameRefiner.cpp index a062a33c2..010fdb9b8 100644 --- a/src/storm/abstraction/MenuGameRefiner.cpp +++ b/src/storm/abstraction/MenuGameRefiner.cpp @@ -195,30 +195,85 @@ namespace storm { STORM_LOG_TRACE("No bottom state successor. Deriving a new predicate using weakest precondition."); // Decode both choices to explicit mappings. - std::map lowerChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(lowerChoice); - std::map upperChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(upperChoice); + std::map> lowerChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(lowerChoice); + std::map> upperChoiceUpdateToSuccessorMapping = abstractionInformation.decodeChoiceToUpdateSuccessorMapping(upperChoice); STORM_LOG_ASSERT(lowerChoiceUpdateToSuccessorMapping.size() == upperChoiceUpdateToSuccessorMapping.size(), "Mismatching sizes after decode (" << lowerChoiceUpdateToSuccessorMapping.size() << " vs. " << upperChoiceUpdateToSuccessorMapping.size() << ")."); - // Now go through the mappings and find points of deviation. Currently, we take the first deviation. - auto lowerIt = lowerChoiceUpdateToSuccessorMapping.begin(); - auto lowerIte = lowerChoiceUpdateToSuccessorMapping.end(); - auto upperIt = upperChoiceUpdateToSuccessorMapping.begin(); - for (; lowerIt != lowerIte; ++lowerIt, ++upperIt) { - STORM_LOG_ASSERT(lowerIt->first == upperIt->first, "Update indices mismatch."); - uint_fast64_t updateIndex = lowerIt->first; - bool deviates = lowerIt->second != upperIt->second; + // First, sort updates according to probability mass. + std::vector> updateIndicesAndMasses; + for (auto const& entry : lowerChoiceUpdateToSuccessorMapping) { + updateIndicesAndMasses.emplace_back(entry.first, entry.second.second); + } + std::sort(updateIndicesAndMasses.begin(), updateIndicesAndMasses.end(), [] (std::pair const& a, std::pair const& b) { return a.second > b.second; }); + + // Now find the update with the highest probability mass among all deviating updates. More specifically, + // we determine the set of predicate indices for which there is a deviation. + std::set deviationPredicates; + uint64_t orderedUpdateIndex = 0; + std::vector possibleRefinementPredicates; + for (; orderedUpdateIndex < updateIndicesAndMasses.size(); ++orderedUpdateIndex) { + storm::storage::BitVector const& lower = lowerChoiceUpdateToSuccessorMapping.at(updateIndicesAndMasses[orderedUpdateIndex].first).first; + storm::storage::BitVector const& upper = upperChoiceUpdateToSuccessorMapping.at(updateIndicesAndMasses[orderedUpdateIndex].first).first; + + bool deviates = lower != upper; if (deviates) { - for (uint_fast64_t predicateIndex = 0; predicateIndex < lowerIt->second.size(); ++predicateIndex) { - if (lowerIt->second.get(predicateIndex) != upperIt->second.get(predicateIndex)) { - // Now we know the point of the deviation (command, update, predicate). - newPredicate = abstractionInformation.getPredicateByIndex(predicateIndex).substitute(abstractor.get().getVariableUpdates(player1Index, updateIndex)).simplify(); - break; + std::map variableUpdates = abstractor.get().getVariableUpdates(player1Index, updateIndicesAndMasses[orderedUpdateIndex].first); + + for (uint64_t predicateIndex = 0; predicateIndex < lower.size(); ++predicateIndex) { + if (lower[predicateIndex] != upper[predicateIndex]) { + possibleRefinementPredicates.push_back(abstractionInformation.getPredicateByIndex(predicateIndex).substitute(variableUpdates).simplify()); + } + } + ++orderedUpdateIndex; + break; + } + } + + STORM_LOG_ASSERT(!possibleRefinementPredicates.empty(), "Expected refinement predicates."); + + // Since we can choose any of the deviation predicates to perform the split, we go through the remaining + // updates and build all deviation predicates. We can then check whether any of the possible refinement + // predicates also eliminates another deviation. + std::vector otherRefinementPredicates; + for (; orderedUpdateIndex < updateIndicesAndMasses.size(); ++orderedUpdateIndex) { + storm::storage::BitVector const& lower = lowerChoiceUpdateToSuccessorMapping.at(updateIndicesAndMasses[orderedUpdateIndex].first).first; + storm::storage::BitVector const& upper = upperChoiceUpdateToSuccessorMapping.at(updateIndicesAndMasses[orderedUpdateIndex].first).first; + + bool deviates = lower != upper; + if (deviates) { + std::map newVariableUpdates = abstractor.get().getVariableUpdates(player1Index, updateIndicesAndMasses[orderedUpdateIndex].first); + for (uint64_t predicateIndex = 0; predicateIndex < lower.size(); ++predicateIndex) { + if (lower[predicateIndex] != upper[predicateIndex]) { + otherRefinementPredicates.push_back(abstractionInformation.getPredicateByIndex(predicateIndex).substitute(newVariableUpdates).simplify()); } } } } + + // Finally, go through the refinement predicates and see how many deviations they cover. + std::vector refinementPredicateIndexToCount(possibleRefinementPredicates.size(), 0); + for (uint64_t index = 0; index < possibleRefinementPredicates.size(); ++index) { + refinementPredicateIndexToCount[index] = 1; + } + for (auto const& otherPredicate : otherRefinementPredicates) { + for (uint64_t index = 0; index < possibleRefinementPredicates.size(); ++index) { + if (equivalenceChecker.areEquivalent(otherPredicate, possibleRefinementPredicates[index])) { + ++refinementPredicateIndexToCount[index]; + } + } + } + + // Find predicate that covers the most deviations. + uint64_t chosenPredicateIndex = 0; + for (uint64_t index = 0; index < possibleRefinementPredicates.size(); ++index) { + if (refinementPredicateIndexToCount[index] > refinementPredicateIndexToCount[chosenPredicateIndex]) { + chosenPredicateIndex = index; + } + } + newPredicate = possibleRefinementPredicates[chosenPredicateIndex]; + STORM_LOG_ASSERT(newPredicate.isInitialized(), "Could not derive new predicate as there is no deviation."); - STORM_LOG_DEBUG("Derived new predicate (based on weakest-precondition): " << newPredicate); + STORM_LOG_DEBUG("Derived new predicate (based on weakest-precondition): " << newPredicate << ", (equivalent to " << (refinementPredicateIndexToCount[chosenPredicateIndex] - 1) << " other refinement predicates)"); } return RefinementPredicates(fromGuard ? RefinementPredicates::Source::Guard : RefinementPredicates::Source::WeakestPrecondition, {newPredicate}); @@ -240,14 +295,16 @@ namespace storm { // Then constrain these states by the requirement that for either the lower or upper player 1 choice the player 2 choices must be different and // that the difference is not because of a missing strategy in either case. - // Start with constructing the player 2 states that have a prob 0 (min) and prob 1 (max) strategy. + // Start with constructing the player 2 states that have a min and a max strategy. storm::dd::Bdd constraint = minPlayer2Strategy.existsAbstract(game.getPlayer2Variables()) && maxPlayer2Strategy.existsAbstract(game.getPlayer2Variables()); // Now construct all player 2 choices that actually exist and differ in the min and max case. constraint &= minPlayer2Strategy.exclusiveOr(maxPlayer2Strategy); + minPlayer2Strategy.exclusiveOr(maxPlayer2Strategy).template toAdd().exportToDot("pl2diff.dot"); + constraint.template toAdd().exportToDot("constraint.dot"); // Then restrict the pivot states by requiring existing and different player 2 choices. - result.pivotStates &= ((minPlayer1Strategy && maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables()); + result.pivotStates &= ((minPlayer1Strategy || maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables()); return result; } diff --git a/src/storm/abstraction/MenuGameRefiner.h b/src/storm/abstraction/MenuGameRefiner.h index 84745aa4a..2894c038f 100644 --- a/src/storm/abstraction/MenuGameRefiner.h +++ b/src/storm/abstraction/MenuGameRefiner.h @@ -62,7 +62,9 @@ namespace storm { MenuGameRefiner(MenuGameAbstractor& abstractor, std::unique_ptr&& smtSolver); /*! - * Refines the abstractor with the given set of predicates. + * Refines the abstractor with the given predicates. + * + * @param predicates The predicates to use for refinement. */ void refine(std::vector const& predicates) const; diff --git a/src/storm/abstraction/StateSetAbstractor.cpp b/src/storm/abstraction/StateSetAbstractor.cpp index 50da8f008..456d89c5e 100644 --- a/src/storm/abstraction/StateSetAbstractor.cpp +++ b/src/storm/abstraction/StateSetAbstractor.cpp @@ -14,7 +14,7 @@ namespace storm { namespace abstraction { template - StateSetAbstractor::StateSetAbstractor(AbstractionInformation& abstractionInformation, std::set const& allVariables, std::vector const& statePredicates, std::shared_ptr const& smtSolverFactory) : smtSolver(smtSolverFactory->create(abstractionInformation.getExpressionManager())), abstractionInformation(abstractionInformation), localExpressionInformation(allVariables), relevantPredicatesAndVariables(), concretePredicateVariables(), needsRecomputation(false), cachedBdd(abstractionInformation.getDdManager().getBddOne()), constraint(abstractionInformation.getDdManager().getBddOne()) { + StateSetAbstractor::StateSetAbstractor(AbstractionInformation& abstractionInformation, std::vector const& statePredicates, std::shared_ptr const& smtSolverFactory) : smtSolver(smtSolverFactory->create(abstractionInformation.getExpressionManager())), abstractionInformation(abstractionInformation), localExpressionInformation(abstractionInformation), relevantPredicatesAndVariables(), concretePredicateVariables(), needsRecomputation(false), cachedBdd(abstractionInformation.getDdManager().getBddOne()), constraint(abstractionInformation.getDdManager().getBddOne()) { // Assert all state predicates. for (auto const& predicate : statePredicates) { @@ -44,7 +44,7 @@ namespace storm { void StateSetAbstractor::refine(std::vector const& newPredicates) { // Make the partition aware of the new predicates, which may make more predicates relevant to the abstraction. for (auto const& predicateIndex : newPredicates) { - localExpressionInformation.addExpression(this->getAbstractionInformation().getPredicateByIndex(predicateIndex), predicateIndex); + localExpressionInformation.addExpression(predicateIndex); } needsRecomputation = true; } diff --git a/src/storm/abstraction/StateSetAbstractor.h b/src/storm/abstraction/StateSetAbstractor.h index 8a3f500c5..9352f2d58 100644 --- a/src/storm/abstraction/StateSetAbstractor.h +++ b/src/storm/abstraction/StateSetAbstractor.h @@ -47,12 +47,11 @@ namespace storm { * Creates a state set abstractor. * * @param abstractionInformation An object storing information about the abstraction such as predicates and BDDs. - * @param allVariables All variables that appear in the predicates. * @param statePredicates A set of predicates that have to hold in the concrete states this abstractor is * supposed to abstract. * @param smtSolverFactory A factory that can create new SMT solvers. */ - StateSetAbstractor(AbstractionInformation& abstractionInformation, std::set const& allVariables, std::vector const& statePredicates, std::shared_ptr const& smtSolverFactory = std::make_shared()); + StateSetAbstractor(AbstractionInformation& abstractionInformation, std::vector const& statePredicates, std::shared_ptr const& smtSolverFactory = std::make_shared()); /*! * Refines the abstractor by making the given predicates new abstract predicates. @@ -135,7 +134,7 @@ namespace storm { std::reference_wrapper> abstractionInformation; // The local expression-related information. - LocalExpressionInformation localExpressionInformation; + LocalExpressionInformation localExpressionInformation; // The set of relevant predicates and the corresponding decision variables. std::vector> relevantPredicatesAndVariables; diff --git a/src/storm/abstraction/prism/CommandAbstractor.cpp b/src/storm/abstraction/prism/CommandAbstractor.cpp index 605bd0e9b..9d21f1c86 100644 --- a/src/storm/abstraction/prism/CommandAbstractor.cpp +++ b/src/storm/abstraction/prism/CommandAbstractor.cpp @@ -23,7 +23,7 @@ namespace storm { namespace abstraction { namespace prism { template - CommandAbstractor::CommandAbstractor(storm::prism::Command const& command, AbstractionInformation& abstractionInformation, std::shared_ptr const& smtSolverFactory, bool guardIsPredicate) : smtSolver(smtSolverFactory->create(abstractionInformation.getExpressionManager())), abstractionInformation(abstractionInformation), command(command), localExpressionInformation(abstractionInformation.getVariables()), evaluator(abstractionInformation.getExpressionManager()), relevantPredicatesAndVariables(), cachedDd(abstractionInformation.getDdManager().getBddZero(), 0), decisionVariables(), skipBottomStates(false), forceRecomputation(true), abstractGuard(abstractionInformation.getDdManager().getBddZero()), bottomStateAbstractor(abstractionInformation, abstractionInformation.getExpressionVariables(), {!command.getGuardExpression()}, smtSolverFactory) { + CommandAbstractor::CommandAbstractor(storm::prism::Command const& command, AbstractionInformation& abstractionInformation, std::shared_ptr const& smtSolverFactory, bool guardIsPredicate) : smtSolver(smtSolverFactory->create(abstractionInformation.getExpressionManager())), abstractionInformation(abstractionInformation), command(command), localExpressionInformation(abstractionInformation), evaluator(abstractionInformation.getExpressionManager()), relevantPredicatesAndVariables(), cachedDd(abstractionInformation.getDdManager().getBddZero(), 0), decisionVariables(), skipBottomStates(false), forceRecomputation(true), abstractGuard(abstractionInformation.getDdManager().getBddZero()), bottomStateAbstractor(abstractionInformation, {!command.getGuardExpression()}, smtSolverFactory) { // Make the second component of relevant predicates have the right size. relevantPredicatesAndVariables.second.resize(command.getNumberOfUpdates()); @@ -42,7 +42,7 @@ namespace storm { void CommandAbstractor::refine(std::vector const& predicates) { // Add all predicates to the variable partition. for (auto predicateIndex : predicates) { - localExpressionInformation.addExpression(this->getAbstractionInformation().getPredicateByIndex(predicateIndex), predicateIndex); + localExpressionInformation.addExpression(predicateIndex); } STORM_LOG_TRACE("Current variable partition is: " << localExpressionInformation); diff --git a/src/storm/abstraction/prism/CommandAbstractor.h b/src/storm/abstraction/prism/CommandAbstractor.h index dbf2b862d..0bacfbba8 100644 --- a/src/storm/abstraction/prism/CommandAbstractor.h +++ b/src/storm/abstraction/prism/CommandAbstractor.h @@ -212,7 +212,7 @@ namespace storm { std::reference_wrapper command; // The local expression-related information. - LocalExpressionInformation localExpressionInformation; + LocalExpressionInformation localExpressionInformation; // The evaluator used to translate the probability expressions. storm::expressions::ExpressionEvaluator evaluator; diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp index 552e00486..ccb8d35fc 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp @@ -31,16 +31,13 @@ namespace storm { template PrismMenuGameAbstractor::PrismMenuGameAbstractor(storm::prism::Program const& program, std::shared_ptr const& smtSolverFactory) - : program(program), smtSolverFactory(smtSolverFactory), abstractionInformation(program.getManager(), smtSolverFactory->create(program.getManager())), modules(), initialStateAbstractor(abstractionInformation, program.getAllExpressionVariables(), {program.getInitialStatesExpression()}, this->smtSolverFactory), currentGame(nullptr), refinementPerformed(false) { + : program(program), smtSolverFactory(smtSolverFactory), abstractionInformation(program.getManager(), program.getAllExpressionVariables(), smtSolverFactory->create(program.getManager())), modules(), initialStateAbstractor(abstractionInformation, {program.getInitialStatesExpression()}, this->smtSolverFactory), currentGame(nullptr), refinementPerformed(false) { // For now, we assume that there is a single module. If the program has more than one module, it needs // to be flattened before the procedure. STORM_LOG_THROW(program.getNumberOfModules() == 1, storm::exceptions::WrongFormatException, "Cannot create abstract program from program containing too many modules."); - // Add all variables and range expressions to the information object. - for (auto const& variable : this->program.get().getAllExpressionVariables()) { - abstractionInformation.addExpressionVariable(variable); - } + // Add all variables range expressions to the information object. for (auto const& range : this->program.get().getAllRangeExpressions()) { abstractionInformation.addConstraint(range); initialStateAbstractor.constrain(range); @@ -162,11 +159,8 @@ namespace storm { // Compute bottom states and the appropriate transitions if necessary. BottomStateResult bottomStateResult(abstractionInformation.getDdManager().getBddZero(), abstractionInformation.getDdManager().getBddZero()); - bool hasBottomStates = false; - if (!addedAllGuards) { - bottomStateResult = modules.front().getBottomStateTransitions(reachableStates, game.numberOfPlayer2Variables); - hasBottomStates = !bottomStateResult.states.isZero(); - } + bottomStateResult = modules.front().getBottomStateTransitions(reachableStates, game.numberOfPlayer2Variables); + bool hasBottomStates = !bottomStateResult.states.isZero(); // Construct the transition matrix by cutting away the transitions of unreachable states. storm::dd::Add transitionMatrix = (game.bdd && reachableStates).template toAdd(); diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h index ae7bd5550..48be5968b 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h @@ -145,9 +145,6 @@ namespace storm { // A state-set abstractor used to determine the initial states of the abstraction. StateSetAbstractor initialStateAbstractor; - // A flag that stores whether all guards were added (which is relevant for determining the bottom states). - bool addedAllGuards; - // An ADD characterizing the probabilities of commands and their updates. storm::dd::Add commandUpdateProbabilitiesAdd; diff --git a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp index 04c4ac82a..49b0af65f 100644 --- a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp +++ b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.cpp @@ -18,6 +18,10 @@ #include "storm/solver/SymbolicGameSolver.h" +#include "storm/settings/SettingsManager.h" +#include "storm/settings/modules/CoreSettings.h" +#include "storm/settings/modules/AbstractionSettings.h" + #include "storm/utility/solver.h" #include "storm/utility/prism.h" #include "storm/utility/macros.h" @@ -37,7 +41,7 @@ namespace storm { using storm::abstraction::QuantitativeResultMinMax; template - GameBasedMdpModelChecker::GameBasedMdpModelChecker(storm::storage::SymbolicModelDescription const& model, std::shared_ptr const& smtSolverFactory) : smtSolverFactory(smtSolverFactory) { + GameBasedMdpModelChecker::GameBasedMdpModelChecker(storm::storage::SymbolicModelDescription const& model, std::shared_ptr const& smtSolverFactory) : smtSolverFactory(smtSolverFactory), comparator(storm::settings::getModule().getPrecision()) { STORM_LOG_THROW(model.isPrismProgram(), storm::exceptions::NotSupportedException, "Currently only PRISM models are supported by the game-based model checker."); storm::prism::Program const& originalProgram = model.asPrismProgram(); STORM_LOG_THROW(originalProgram.getModelType() == storm::prism::Program::ModelType::DTMC || originalProgram.getModelType() == storm::prism::Program::ModelType::MDP, storm::exceptions::NotSupportedException, "Currently only DTMCs/MDPs are supported by the game-based model checker."); @@ -165,11 +169,11 @@ namespace storm { } template - std::unique_ptr checkForResultAfterQuantitativeCheck(CheckTask const& checkTask, ValueType const& minValue, ValueType const& maxValue) { + std::unique_ptr checkForResultAfterQuantitativeCheck(CheckTask const& checkTask, ValueType const& minValue, ValueType const& maxValue, storm::utility::ConstantsComparator const& comparator) { std::unique_ptr result; // If the lower and upper bounds are close enough, we can return the result. - if (maxValue - minValue < storm::utility::convertNumber(1e-3)) { + if (comparator.isEqual(minValue, maxValue)) { result = std::make_unique>(storm::storage::sparse::state_type(0), (minValue + maxValue) / ValueType(2)); } @@ -304,7 +308,7 @@ namespace storm { } // #ifdef LOCAL_DEBUG - abstractor.exportToDot("game" + std::to_string(iterations) + ".dot", targetStates, game.getManager().getBddOne()); + // abstractor.exportToDot("game" + std::to_string(iterations) + ".dot", targetStates, game.getManager().getBddOne()); // #endif // (3) compute all states with probability 0/1 wrt. to the two different player 2 goals (min/max). @@ -312,6 +316,7 @@ namespace storm { QualitativeResultMinMax qualitativeResult; std::unique_ptr result = computeProb01States(checkTask, qualitativeResult, game, player1Direction, transitionMatrixBdd, initialStates, constraintStates, targetStates); if (result) { + printStatistics(abstractor, game); return result; } auto qualitativeEnd = std::chrono::high_resolution_clock::now(); @@ -330,6 +335,7 @@ namespace storm { result = checkForResultAfterQualitativeCheck(checkTask, initialStates, qualitativeResult); if (result) { + printStatistics(abstractor, game); return result; } else { STORM_LOG_DEBUG("Obtained qualitative bounds [0, 1] on the actual value for the initial states."); @@ -355,6 +361,7 @@ namespace storm { quantitativeResult.min = computeQuantitativeResult(player1Direction, storm::OptimizationDirection::Minimize, game, qualitativeResult, initialStatesAdd, maybeMin); result = checkForResultAfterQuantitativeCheck(checkTask, storm::OptimizationDirection::Minimize, quantitativeResult.min.initialStateValue); if (result) { + printStatistics(abstractor, game); return result; } @@ -362,6 +369,7 @@ namespace storm { quantitativeResult.max = computeQuantitativeResult(player1Direction, storm::OptimizationDirection::Maximize, game, qualitativeResult, initialStatesAdd, maybeMax, boost::make_optional(quantitativeResult.min)); result = checkForResultAfterQuantitativeCheck(checkTask, storm::OptimizationDirection::Maximize, quantitativeResult.max.initialStateValue); if (result) { + printStatistics(abstractor, game); return result; } @@ -369,8 +377,9 @@ namespace storm { STORM_LOG_DEBUG("Obtained quantitative bounds [" << quantitativeResult.min.initialStateValue << ", " << quantitativeResult.max.initialStateValue << "] on the actual value for the initial states in " << std::chrono::duration_cast(quantitativeEnd - quantitativeStart).count() << "ms."); // (9) Check whether the lower and upper bounds are close enough to terminate with an answer. - result = checkForResultAfterQuantitativeCheck(checkTask, quantitativeResult.min.initialStateValue, quantitativeResult.max.initialStateValue); + result = checkForResultAfterQuantitativeCheck(checkTask, quantitativeResult.min.initialStateValue, quantitativeResult.max.initialStateValue, comparator); if (result) { + printStatistics(abstractor, game); return result; } @@ -420,15 +429,10 @@ namespace storm { qualitativeResult.prob0Min = computeProb01States(true, player1Direction, storm::OptimizationDirection::Minimize, game, transitionMatrixBdd, constraintStates, targetStates); qualitativeResult.prob1Min = computeProb01States(false, player1Direction, storm::OptimizationDirection::Minimize, game, transitionMatrixBdd, constraintStates, targetStates); std::unique_ptr result = checkForResultAfterQualitativeCheck(checkTask, storm::OptimizationDirection::Minimize, initialStates, qualitativeResult.prob0Min.getPlayer1States(), qualitativeResult.prob1Min.getPlayer1States()); - if (result) { - return result; - } - - qualitativeResult.prob0Max = computeProb01States(true, player1Direction, storm::OptimizationDirection::Maximize, game, transitionMatrixBdd, constraintStates, targetStates); - qualitativeResult.prob1Max = computeProb01States(false, player1Direction, storm::OptimizationDirection::Maximize, game, transitionMatrixBdd, constraintStates, targetStates); - result = checkForResultAfterQualitativeCheck(checkTask, storm::OptimizationDirection::Maximize, initialStates, qualitativeResult.prob0Max.getPlayer1States(), qualitativeResult.prob1Max.getPlayer1States()); - if (result) { - return result; + if (!result) { + qualitativeResult.prob0Max = computeProb01States(true, player1Direction, storm::OptimizationDirection::Maximize, game, transitionMatrixBdd, constraintStates, targetStates); + qualitativeResult.prob1Max = computeProb01States(false, player1Direction, storm::OptimizationDirection::Maximize, game, transitionMatrixBdd, constraintStates, targetStates); + result = checkForResultAfterQualitativeCheck(checkTask, storm::OptimizationDirection::Maximize, initialStates, qualitativeResult.prob0Max.getPlayer1States(), qualitativeResult.prob1Max.getPlayer1States()); } return result; } @@ -454,6 +458,19 @@ namespace storm { return result; } + template + void GameBasedMdpModelChecker::printStatistics(storm::abstraction::MenuGameAbstractor const& abstractor, storm::abstraction::MenuGame const& game) const { + if (storm::settings::getModule().isShowStatisticsSet()) { + storm::abstraction::AbstractionInformation const& abstractionInformation = abstractor.getAbstractionInformation(); + + std::cout << std::endl; + std::cout << "Statistics:" << std::endl; + std::cout << " * player 1 states (final game): " << game.getReachableStates().getNonZeroCount() << std::endl; + std::cout << " * transitions (final game): " << game.getTransitionMatrix().getNonZeroCount() << std::endl; + std::cout << " * predicates used in abstraction: " << abstractionInformation.getNumberOfPredicates() << std::endl; + } + } + template storm::expressions::Expression GameBasedMdpModelChecker::getExpression(storm::logic::Formula const& formula) { STORM_LOG_THROW(formula.isBooleanLiteralFormula() || formula.isAtomicExpressionFormula() || formula.isAtomicLabelFormula(), storm::exceptions::InvalidPropertyException, "The target states have to be given as label or an expression."); diff --git a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.h b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.h index be840c45e..07a3dbce8 100644 --- a/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.h +++ b/src/storm/modelchecker/abstraction/GameBasedMdpModelChecker.h @@ -14,6 +14,7 @@ #include "storm/logic/Bound.h" +#include "storm/utility/ConstantsComparator.h" #include "storm/utility/solver.h" #include "storm/utility/graph.h" @@ -21,6 +22,9 @@ namespace storm { namespace abstraction { template class MenuGame; + + template + class MenuGameAbstractor; } namespace modelchecker { @@ -70,6 +74,8 @@ namespace storm { std::unique_ptr computeProb01States(CheckTask const& checkTask, QualitativeResultMinMax& qualitativeResult, storm::abstraction::MenuGame const& game, storm::OptimizationDirection player1Direction, storm::dd::Bdd const& transitionMatrixBdd, storm::dd::Bdd const& initialStates, storm::dd::Bdd const& constraintStates, storm::dd::Bdd const& targetStates); QualitativeResult computeProb01States(bool prob0, storm::OptimizationDirection player1Direction, storm::OptimizationDirection player2Direction, storm::abstraction::MenuGame const& game, storm::dd::Bdd const& transitionMatrixBdd, storm::dd::Bdd const& constraintStates, storm::dd::Bdd const& targetStates); + void printStatistics(storm::abstraction::MenuGameAbstractor const& abstractor, storm::abstraction::MenuGame const& game) const; + /* * Retrieves the expression characterized by the formula. The formula needs to be propositional. */ @@ -81,6 +87,9 @@ namespace storm { /// A factory that is used for creating SMT solvers when needed. std::shared_ptr smtSolverFactory; + + /// A comparator that can be used for detecting convergence. + storm::utility::ConstantsComparator comparator; }; } } diff --git a/src/storm/settings/modules/AbstractionSettings.cpp b/src/storm/settings/modules/AbstractionSettings.cpp index 7f55ffd56..9b58fb49c 100644 --- a/src/storm/settings/modules/AbstractionSettings.cpp +++ b/src/storm/settings/modules/AbstractionSettings.cpp @@ -2,6 +2,8 @@ #include "storm/settings/Option.h" #include "storm/settings/OptionBuilder.h" +#include "storm/settings/ArgumentBuilder.h" +#include "storm/settings/Argument.h" namespace storm { namespace settings { @@ -15,7 +17,8 @@ namespace storm { const std::string AbstractionSettings::useInterpolationOptionName = "interpolation"; const std::string AbstractionSettings::splitInterpolantsOptionName = "split-interpolants"; const std::string AbstractionSettings::splitAllOptionName = "split-all"; - + const std::string AbstractionSettings::precisionOptionName = "precision"; + AbstractionSettings::AbstractionSettings() : ModuleSettings(moduleName) { this->addOption(storm::settings::OptionBuilder(moduleName, addAllGuardsOptionName, true, "Sets whether all guards are added as initial predicates.").build()); this->addOption(storm::settings::OptionBuilder(moduleName, splitPredicatesOptionName, true, "Sets whether the predicates are split into atoms before they are added.").build()); @@ -23,6 +26,7 @@ namespace storm { this->addOption(storm::settings::OptionBuilder(moduleName, splitGuardsOptionName, true, "Sets whether the guards are split into atoms before they are added.").build()); this->addOption(storm::settings::OptionBuilder(moduleName, splitAllOptionName, true, "Sets whether all predicates are split into atoms before they are added.").build()); this->addOption(storm::settings::OptionBuilder(moduleName, useInterpolationOptionName, true, "Sets whether interpolation is to be used to eliminate spurious pivot blocks.").build()); + this->addOption(storm::settings::OptionBuilder(moduleName, precisionOptionName, true, "The precision used for detecting convergence.").addArgument(storm::settings::ArgumentBuilder::createDoubleArgument("value", "The precision to achieve.").setDefaultValueDouble(1e-03).addValidationFunctionDouble(storm::settings::ArgumentValidators::doubleRangeValidatorExcluding(0.0, 1.0)).build()).build()); } bool AbstractionSettings::isAddAllGuardsSet() const { @@ -49,6 +53,9 @@ namespace storm { return this->getOption(useInterpolationOptionName).getHasOptionBeenSet(); } + double AbstractionSettings::getPrecision() const { + return this->getOption(precisionOptionName).getArgumentByName("value").getValueAsDouble(); + } } } } diff --git a/src/storm/settings/modules/AbstractionSettings.h b/src/storm/settings/modules/AbstractionSettings.h index 35f577cfd..9f95a5640 100644 --- a/src/storm/settings/modules/AbstractionSettings.h +++ b/src/storm/settings/modules/AbstractionSettings.h @@ -57,6 +57,13 @@ namespace storm { * @return True iff the option was set. */ bool isUseInterpolationSet() const; + + /*! + * Retrieves the precision that is used for detecting convergence. + * + * @return The precision to use for detecting convergence. + */ + double getPrecision() const; const static std::string moduleName; @@ -68,6 +75,7 @@ namespace storm { const static std::string useInterpolationOptionName; const static std::string splitInterpolantsOptionName; const static std::string splitAllOptionName; + const static std::string precisionOptionName; }; }