diff --git a/src/storm/abstraction/MenuGameRefiner.cpp b/src/storm/abstraction/MenuGameRefiner.cpp index a370b0b29..51051415b 100644 --- a/src/storm/abstraction/MenuGameRefiner.cpp +++ b/src/storm/abstraction/MenuGameRefiner.cpp @@ -11,6 +11,18 @@ namespace storm { namespace abstraction { + RefinementPredicates::RefinementPredicates(Source const& source, std::vector const& predicates) : source(source), predicates(predicates) { + // Intentionally left empty. + } + + RefinementPredicates::Source RefinementPredicates::getSource() const { + return source; + } + + std::vector const& RefinementPredicates::getPredicates() const { + return predicates; + } + template MenuGameRefiner::MenuGameRefiner(MenuGameAbstractor& abstractor, std::unique_ptr&& smtSolver) : abstractor(abstractor), splitPredicates(storm::settings::getModule().isSplitPredicatesSet()), splitGuards(storm::settings::getModule().isSplitGuardsSet()), splitter(), equivalenceChecker(std::move(smtSolver)) { @@ -31,7 +43,7 @@ namespace storm { } template - storm::dd::Bdd pickPivotStateWithMinimalDistance(storm::dd::Bdd const& initialStates, storm::dd::Bdd const& transitionsMin, storm::dd::Bdd const& transitionsMax, std::set const& rowVariables, std::set const& columnVariables, storm::dd::Bdd const& pivotStates, boost::optional> const& quantitativeResult = boost::none) { + std::pair, storm::OptimizationDirection> pickPivotState(storm::dd::Bdd const& initialStates, storm::dd::Bdd const& transitionsMin, storm::dd::Bdd const& transitionsMax, std::set const& rowVariables, std::set const& columnVariables, storm::dd::Bdd const& pivotStates, boost::optional> const& quantitativeResult = boost::none) { // Set up used variables. storm::dd::Bdd frontierMin = initialStates; @@ -43,7 +55,7 @@ namespace storm { bool foundPivotState = !frontierPivotStates.isZero(); if (foundPivotState) { STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); - return frontierPivotStates.existsAbstractRepresentative(rowVariables); + return std::make_pair(frontierPivotStates.existsAbstractRepresentative(rowVariables), storm::OptimizationDirection::Minimize); } else { // Otherwise, we perform a simulatenous BFS in the sense that we make one step in both the min and max @@ -52,17 +64,40 @@ namespace storm { frontierMin = frontierMin.relationalProduct(transitionsMin, rowVariables, columnVariables); frontierMax = frontierMax.relationalProduct(transitionsMax, rowVariables, columnVariables); - frontierPivotStates = (frontierMin && pivotStates) || (frontierMax && pivotStates); + storm::dd::Bdd frontierMinPivotStates = frontierMin && pivotStates; + storm::dd::Bdd frontierMaxPivotStates = frontierMax && pivotStates; + uint64_t numberOfPivotStateCandidatesOnLevel = frontierMinPivotStates.getNonZeroCount() + frontierMaxPivotStates.getNonZeroCount(); - if (!frontierPivotStates.isZero()) { + if (!frontierMinPivotStates.isZero() || !frontierMaxPivotStates.isZero()) { if (quantitativeResult) { - storm::dd::Add frontierPivotStatesAdd = frontierPivotStates.template toAdd(); - storm::dd::Add diff = frontierPivotStatesAdd * quantitativeResult.get().max.values - frontierPivotStatesAdd * quantitativeResult.get().min.values; - STORM_LOG_TRACE("Picked pivot state with difference " << diff.getMax() << " from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); - return diff.maxAbstractRepresentative(rowVariables); + storm::dd::Add frontierMinPivotStatesAdd = frontierMinPivotStates.template toAdd(); + storm::dd::Add frontierMaxPivotStatesAdd = frontierMaxPivotStates.template toAdd(); + storm::dd::Add diffMin = frontierMinPivotStatesAdd * quantitativeResult.get().max.values - frontierMinPivotStatesAdd * quantitativeResult.get().min.values; + storm::dd::Add diffMax = frontierMaxPivotStatesAdd * quantitativeResult.get().max.values - frontierMaxPivotStatesAdd * quantitativeResult.get().min.values; + + ValueType diffValue; + storm::OptimizationDirection direction; + if (diffMin.getMax() >= diffMax.getMax()) { + direction = storm::OptimizationDirection::Minimize; + diffValue = diffMin.getMax(); + } else { + direction = storm::OptimizationDirection::Maximize; + diffValue = diffMax.getMax(); + } + + STORM_LOG_TRACE("Picked pivot state with difference " << diffValue << " from " << numberOfPivotStateCandidatesOnLevel << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); + return std::make_pair(direction == storm::OptimizationDirection::Minimize ? diffMin.maxAbstractRepresentative(rowVariables) : diffMax.maxAbstractRepresentative(rowVariables), direction); } else { - STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); - return frontierPivotStates.existsAbstractRepresentative(rowVariables); + STORM_LOG_TRACE("Picked pivot state from " << numberOfPivotStateCandidatesOnLevel << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); + + storm::OptimizationDirection direction; + if (!frontierMinPivotStates.isZero()) { + direction = storm::OptimizationDirection::Minimize; + } else { + direction = storm::OptimizationDirection::Maximize; + } + + return std::make_pair(direction == storm::OptimizationDirection::Minimize ? frontierMinPivotStates.existsAbstractRepresentative(rowVariables) : frontierMaxPivotStates.existsAbstractRepresentative(rowVariables), direction); } } ++level; @@ -70,11 +105,11 @@ namespace storm { } STORM_LOG_ASSERT(false, "This point must not be reached, because then no pivot state could be found."); - return storm::dd::Bdd(); + return std::make_pair(storm::dd::Bdd(), storm::OptimizationDirection::Minimize); } - + template - std::pair MenuGameRefiner::derivePredicateFromDifferingChoices(storm::dd::Bdd const& pivotState, storm::dd::Bdd const& player1Choice, storm::dd::Bdd const& lowerChoice, storm::dd::Bdd const& upperChoice) const { + RefinementPredicates MenuGameRefiner::derivePredicatesFromDifferingChoices(storm::dd::Bdd const& pivotState, storm::dd::Bdd const& player1Choice, storm::dd::Bdd const& lowerChoice, storm::dd::Bdd const& upperChoice) const { // Prepare result. storm::expressions::Expression newPredicate; bool fromGuard = false; @@ -132,7 +167,7 @@ namespace storm { for (auto const& predicate : abstractionInformation.getPredicates()) { STORM_LOG_TRACE(predicate); } - return std::make_pair(newPredicate, fromGuard); + return RefinementPredicates(fromGuard ? RefinementPredicates::Source::Guard : RefinementPredicates::Source::WeakestPrecondition, {newPredicate}); } template @@ -153,7 +188,7 @@ namespace storm { // Start with all reachable states as potential pivot states. result.pivotStates = storm::utility::dd::computeReachableStates(game.getInitialStates(), result.reachableTransitionsMin, game.getRowVariables(), game.getColumnVariables()) || - storm::utility::dd::computeReachableStates(game.getInitialStates(), result.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables()); + storm::utility::dd::computeReachableStates(game.getInitialStates(), result.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables()); // Then constrain these states by the requirement that for either the lower or upper player 1 choice the player 2 choices must be different and // that the difference is not because of a missing strategy in either case. @@ -171,7 +206,7 @@ namespace storm { } template - std::pair MenuGameRefiner::derivePredicateFromPivotState(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& pivotState, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const { + RefinementPredicates MenuGameRefiner::derivePredicatesFromPivotState(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& pivotState, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const { // Compute the lower and the upper choice for the pivot state. std::set variablesToAbstract = game.getNondeterminismVariables(); variablesToAbstract.insert(game.getRowVariables().begin(), game.getRowVariables().end()); @@ -184,10 +219,10 @@ namespace storm { STORM_LOG_TRACE("Refining based on lower choice."); auto refinementStart = std::chrono::high_resolution_clock::now(); - std::pair newPredicate = derivePredicateFromDifferingChoices(pivotState, (pivotState && minPlayer1Strategy).existsAbstract(game.getRowVariables()), lowerChoice1, lowerChoice2); + RefinementPredicates predicates = derivePredicatesFromDifferingChoices(pivotState, (pivotState && minPlayer1Strategy).existsAbstract(game.getRowVariables()), lowerChoice1, lowerChoice2); auto refinementEnd = std::chrono::high_resolution_clock::now(); STORM_LOG_TRACE("Refinement completed in " << std::chrono::duration_cast(refinementEnd - refinementStart).count() << "ms."); - return newPredicate; + return predicates; } else { storm::dd::Bdd upperChoice = pivotState && game.getExtendedTransitionMatrix().toBdd() && maxPlayer1Strategy; storm::dd::Bdd upperChoice1 = (upperChoice && minPlayer2Strategy).existsAbstract(variablesToAbstract); @@ -197,10 +232,10 @@ namespace storm { if (upperChoicesDifferent) { STORM_LOG_TRACE("Refining based on upper choice."); auto refinementStart = std::chrono::high_resolution_clock::now(); - std::pair newPredicate = derivePredicateFromDifferingChoices(pivotState, (pivotState && maxPlayer1Strategy).existsAbstract(game.getRowVariables()), upperChoice1, upperChoice2); + RefinementPredicates predicates = derivePredicatesFromDifferingChoices(pivotState, (pivotState && maxPlayer1Strategy).existsAbstract(game.getRowVariables()), upperChoice1, upperChoice2); auto refinementEnd = std::chrono::high_resolution_clock::now(); STORM_LOG_TRACE("Refinement completed in " << std::chrono::duration_cast(refinementEnd - refinementStart).count() << "ms."); - return newPredicate; + return predicates; } else { STORM_LOG_ASSERT(false, "Did not find choices from which to derive predicates."); } @@ -233,11 +268,11 @@ namespace storm { STORM_LOG_ASSERT(!pivotStateResult.pivotStates.isZero(), "Unable to proceed without pivot state candidates."); // Now that we have the pivot state candidates, we need to pick one. - storm::dd::Bdd pivotState = pickPivotStateWithMinimalDistance(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates); + std::pair, storm::OptimizationDirection> pivotState = pickPivotState(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates); // Derive predicate based on the selected pivot state. - std::pair newPredicate = derivePredicateFromPivotState(game, pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); - std::vector preparedPredicates = preprocessPredicates({newPredicate.first}, (newPredicate.second && splitGuards) || (!newPredicate.second && splitPredicates)); + RefinementPredicates predicates = derivePredicatesFromPivotState(game, pivotState.first, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + std::vector preparedPredicates = preprocessPredicates(predicates.getPredicates(), (predicates.getSource() == RefinementPredicates::Source::Guard && splitGuards) || (predicates.getSource() == RefinementPredicates::Source::WeakestPrecondition && splitPredicates)); performRefinement(createGlobalRefinement(preparedPredicates)); return true; } @@ -253,19 +288,15 @@ namespace storm { // Compute all reached pivot states. PivotStateResult pivotStateResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); - - // TODO: required? - // Require the pivot state to be a state with a lower bound strictly smaller than the upper bound. - pivotStateResult.pivotStates &= quantitativeResult.min.values.less(quantitativeResult.max.values); STORM_LOG_ASSERT(!pivotStateResult.pivotStates.isZero(), "Unable to refine without pivot state candidates."); - + // Now that we have the pivot state candidates, we need to pick one. - storm::dd::Bdd pivotState = pickPivotStateWithMinimalDistance(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates); - + std::pair, storm::OptimizationDirection> pivotState = pickPivotState(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates); + // Derive predicate based on the selected pivot state. - std::pair newPredicate = derivePredicateFromPivotState(game, pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); - std::vector preparedPredicates = preprocessPredicates({newPredicate.first}, (newPredicate.second && splitGuards) || (!newPredicate.second && splitPredicates)); + RefinementPredicates predicates = derivePredicatesFromPivotState(game, pivotState.first, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + std::vector preparedPredicates = preprocessPredicates(predicates.getPredicates(), (predicates.getSource() == RefinementPredicates::Source::Guard && splitGuards) || (predicates.getSource() == RefinementPredicates::Source::WeakestPrecondition && splitPredicates)); performRefinement(createGlobalRefinement(preparedPredicates)); return true; } @@ -274,10 +305,10 @@ namespace storm { std::vector MenuGameRefiner::preprocessPredicates(std::vector const& predicates, bool split) const { if (split) { std::vector cleanedAtoms; - + for (auto const& predicate : predicates) { AbstractionInformation const& abstractionInformation = abstractor.get().getAbstractionInformation(); - + // Split the predicates. std::vector atoms = splitter.split(predicate); diff --git a/src/storm/abstraction/MenuGameRefiner.h b/src/storm/abstraction/MenuGameRefiner.h index e5be859c4..54c778ae2 100644 --- a/src/storm/abstraction/MenuGameRefiner.h +++ b/src/storm/abstraction/MenuGameRefiner.h @@ -25,6 +25,22 @@ namespace storm { template class MenuGame; + class RefinementPredicates { + public: + enum class Source { + WeakestPrecondition, Guard, Interpolation + }; + + RefinementPredicates(Source const& source, std::vector const& predicates); + + Source getSource() const; + std::vector const& getPredicates() const; + + private: + Source source; + std::vector predicates; + }; + template class MenuGameRefiner { public: @@ -53,13 +69,13 @@ namespace storm { bool refine(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& transitionMatrixBdd, QuantitativeResultMinMax const& quantitativeResult) const; private: - std::pair derivePredicateFromDifferingChoices(storm::dd::Bdd const& pivotState, storm::dd::Bdd const& player1Choice, storm::dd::Bdd const& lowerChoice, storm::dd::Bdd const& upperChoice) const; - std::pair derivePredicateFromPivotState(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& pivotState, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const; + RefinementPredicates derivePredicatesFromDifferingChoices(storm::dd::Bdd const& pivotState, storm::dd::Bdd const& player1Choice, storm::dd::Bdd const& lowerChoice, storm::dd::Bdd const& upperChoice) const; + RefinementPredicates derivePredicatesFromPivotState(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& pivotState, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const; /*! * Preprocesses the predicates. */ - std::vector preprocessPredicates(std::vector const& predicates, bool allowSplits) const; + std::vector preprocessPredicates(std::vector const& predicates, bool split) const; /*! * Creates a set of refinement commands that amounts to splitting all player 1 choices with the given set of predicates.