From 0f9c64f1a9d6666a9beffb0d1e6fa423ffa57c96 Mon Sep 17 00:00:00 2001
From: dehnert <dehnert@cs.rwth-aachen.de>
Date: Sun, 27 Nov 2016 10:19:32 +0100
Subject: [PATCH] some refactoring of refiner

---
 src/storm/abstraction/MenuGameRefiner.cpp | 242 +++++++++-------------
 src/storm/abstraction/MenuGameRefiner.h   |   4 +-
 2 files changed, 100 insertions(+), 146 deletions(-)

diff --git a/src/storm/abstraction/MenuGameRefiner.cpp b/src/storm/abstraction/MenuGameRefiner.cpp
index e9a682c32..ca26fd8dd 100644
--- a/src/storm/abstraction/MenuGameRefiner.cpp
+++ b/src/storm/abstraction/MenuGameRefiner.cpp
@@ -22,61 +22,54 @@ namespace storm {
         }
         
         template<storm::dd::DdType Type, typename ValueType>
-        storm::dd::Bdd<Type> pickPivotState(storm::dd::Bdd<Type> const& initialStates, storm::dd::Bdd<Type> const& transitionsMin, storm::dd::Bdd<Type> const& transitionsMax, std::set<storm::expressions::Variable> const& rowVariables, std::set<storm::expressions::Variable> const& columnVariables, storm::dd::Bdd<Type> const& pivotStates, boost::optional<QuantitativeResultMinMax<Type, ValueType>> const& quantitativeResult = boost::none) {
+        storm::dd::Bdd<Type> pickPivotStateWithMinimalDistance(storm::dd::Bdd<Type> const& initialStates, storm::dd::Bdd<Type> const& transitionsMin, storm::dd::Bdd<Type> const& transitionsMax, std::set<storm::expressions::Variable> const& rowVariables, std::set<storm::expressions::Variable> const& columnVariables, storm::dd::Bdd<Type> const& pivotStates, boost::optional<QuantitativeResultMinMax<Type, ValueType>> const& quantitativeResult = boost::none) {
             
-            // Perform a BFS and pick the first pivot state we encounter.
-            storm::dd::Bdd<Type> pivotState;
+            // Set up used variables.
             storm::dd::Bdd<Type> frontierMin = initialStates;
             storm::dd::Bdd<Type> frontierMax = initialStates;
-            storm::dd::Bdd<Type> frontierMinPivotStates = frontierMin && pivotStates;
-            storm::dd::Bdd<Type> frontierMaxPivotStates = frontierMinPivotStates;
+            storm::dd::Bdd<Type> frontierPivotStates = frontierMin && pivotStates;
             
+            // Check whether we have pivot states on the very first level.
             uint64_t level = 0;
-            bool foundPivotState = !frontierMinPivotStates.isZero();
+            bool foundPivotState = !frontierPivotStates.isZero();
             if (foundPivotState) {
-                pivotState = frontierMinPivotStates.existsAbstractRepresentative(rowVariables);
-                STORM_LOG_TRACE("Picked pivot state from " << frontierMinPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
+                STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
+                return frontierPivotStates.existsAbstractRepresentative(rowVariables);
             } else {
+                
+                // Otherwise, we perform a simulatenous BFS in the sense that we make one step in both the min and max
+                // transitions and check for pivot states we encounter.
                 while (!foundPivotState) {
                     frontierMin = frontierMin.relationalProduct(transitionsMin, rowVariables, columnVariables);
                     frontierMax = frontierMax.relationalProduct(transitionsMax, rowVariables, columnVariables);
-                    frontierMinPivotStates = frontierMin && pivotStates;
-                    frontierMaxPivotStates = frontierMax && pivotStates;
                     
-                    if (!frontierMinPivotStates.isZero()) {
-                        if (quantitativeResult) {
-                            storm::dd::Add<Type, ValueType> frontierPivotStatesAdd = frontierMinPivotStates.template toAdd<ValueType>();
-                            storm::dd::Add<Type, ValueType> diff = frontierPivotStatesAdd * quantitativeResult.get().max.values - frontierPivotStatesAdd * quantitativeResult.get().min.values;
-                            pivotState = diff.maxAbstractRepresentative(rowVariables);
-                            STORM_LOG_TRACE("Picked pivot state with difference " << diff.getMax() << " from " << (frontierMinPivotStates.getNonZeroCount() + frontierMaxPivotStates.getNonZeroCount()) << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
-                            foundPivotState = true;
-                        } else {
-                            pivotState = frontierMinPivotStates.existsAbstractRepresentative(rowVariables);
-                            STORM_LOG_TRACE("Picked pivot state from " << (frontierMinPivotStates.getNonZeroCount() + frontierMaxPivotStates.getNonZeroCount()) << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
-                            foundPivotState = true;
-                        }
-                    } else if (!frontierMaxPivotStates.isZero()) {
+                    frontierPivotStates = (frontierMin && pivotStates) || (frontierMax && pivotStates);
+                    
+                    if (!frontierPivotStates.isZero()) {
                         if (quantitativeResult) {
-                            storm::dd::Add<Type, ValueType> frontierPivotStatesAdd = frontierMaxPivotStates.template toAdd<ValueType>();
+                            storm::dd::Add<Type, ValueType> frontierPivotStatesAdd = frontierPivotStates.template toAdd<ValueType>();
                             storm::dd::Add<Type, ValueType> diff = frontierPivotStatesAdd * quantitativeResult.get().max.values - frontierPivotStatesAdd * quantitativeResult.get().min.values;
-                            pivotState = diff.maxAbstractRepresentative(rowVariables);
-                            STORM_LOG_TRACE("Picked pivot state with difference " << diff.getMax() << " from " << (frontierMinPivotStates.getNonZeroCount() + frontierMaxPivotStates.getNonZeroCount()) << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
-                            foundPivotState = true;
+                            STORM_LOG_TRACE("Picked pivot state with difference " << diff.getMax() << " from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
+                            return diff.maxAbstractRepresentative(rowVariables);
                         } else {
-                            pivotState = frontierMinPivotStates.existsAbstractRepresentative(rowVariables);
-                            STORM_LOG_TRACE("Picked pivot state from " << (frontierMinPivotStates.getNonZeroCount() + frontierMaxPivotStates.getNonZeroCount()) << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
-                            foundPivotState = true;
+                            STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total.");
+                            return frontierPivotStates.existsAbstractRepresentative(rowVariables);
                         }
                     }
                     ++level;
                 }
             }
             
-            return pivotState;
+            STORM_LOG_ASSERT(false, "This point must not be reached, because then no pivot state could be found.");
+            return storm::dd::Bdd<Type>();
         }
 
         template <storm::dd::DdType Type, typename ValueType>
-        void MenuGameRefiner<Type, ValueType>::refine(storm::dd::Bdd<Type> const& pivotState, storm::dd::Bdd<Type> const& player1Choice, storm::dd::Bdd<Type> const& lowerChoice, storm::dd::Bdd<Type> const& upperChoice) const {
+        storm::expressions::Expression MenuGameRefiner<Type, ValueType>::derivePredicateFromDifferingChoices(storm::dd::Bdd<Type> const& pivotState, storm::dd::Bdd<Type> const& player1Choice, storm::dd::Bdd<Type> const& lowerChoice, storm::dd::Bdd<Type> const& upperChoice) const {
+            // Prepare result.
+            storm::expressions::Expression newPredicate;
+            
+            // Get abstraction informatin for easier access.
             AbstractionInformation<Type> const& abstractionInformation = abstractor.get().getAbstractionInformation();
             
             // Decode the index of the command chosen by player 1.
@@ -92,9 +85,8 @@ namespace storm {
             // command (that is the player 1 choice).
             if (buttomStateSuccessor) {
                 STORM_LOG_TRACE("One of the successors is a bottom state, taking a guard as a new predicate.");
-                storm::expressions::Expression newPredicate = abstractor.get().getGuard(player1Index);
-                STORM_LOG_DEBUG("Derived new predicate: " << newPredicate);
-                this->performRefinement({newPredicate});
+                newPredicate = abstractor.get().getGuard(player1Index);
+                STORM_LOG_DEBUG("Derived new predicate (based on guard): " << newPredicate);
             } else {
                 STORM_LOG_TRACE("No bottom state successor. Deriving a new predicate using weakest precondition.");
                 
@@ -104,7 +96,6 @@ namespace storm {
                 STORM_LOG_ASSERT(lowerChoiceUpdateToSuccessorMapping.size() == upperChoiceUpdateToSuccessorMapping.size(), "Mismatching sizes after decode (" << lowerChoiceUpdateToSuccessorMapping.size() << " vs. " << upperChoiceUpdateToSuccessorMapping.size() << ").");
                 
                 // Now go through the mappings and find points of deviation. Currently, we take the first deviation.
-                storm::expressions::Expression newPredicate;
                 auto lowerIt = lowerChoiceUpdateToSuccessorMapping.begin();
                 auto lowerIte = lowerChoiceUpdateToSuccessorMapping.end();
                 auto upperIt = upperChoiceUpdateToSuccessorMapping.begin();
@@ -123,43 +114,36 @@ namespace storm {
                     }
                 }
                 STORM_LOG_ASSERT(newPredicate.isInitialized(), "Could not derive new predicate as there is no deviation.");
-                
-                STORM_LOG_DEBUG("Derived new predicate: " << newPredicate);
-                this->performRefinement({newPredicate});
+                STORM_LOG_DEBUG("Derived new predicate (based on weakest-precondition): " << newPredicate);
             }
             
             STORM_LOG_TRACE("Current set of predicates:");
             for (auto const& predicate : abstractionInformation.getPredicates()) {
                 STORM_LOG_TRACE(predicate);
             }
+            return newPredicate;
         }
         
+        template<storm::dd::DdType Type>
+        struct PivotStateResult {
+            storm::dd::Bdd<Type> reachableTransitionsMin;
+            storm::dd::Bdd<Type> reachableTransitionsMax;
+            storm::dd::Bdd<Type> pivotStates;
+        };
+        
         template<storm::dd::DdType Type, typename ValueType>
-        bool MenuGameRefiner<Type, ValueType>::refine(storm::abstraction::MenuGame<Type, ValueType> const& game, storm::dd::Bdd<Type> const& transitionMatrixBdd, QualitativeResultMinMax<Type> const& qualitativeResult) const {
-            STORM_LOG_TRACE("Trying refinement after qualitative check.");
-            // Get all relevant strategies.
-            storm::dd::Bdd<Type> minPlayer1Strategy = qualitativeResult.prob0Min.getPlayer1Strategy();
-            storm::dd::Bdd<Type> minPlayer2Strategy = qualitativeResult.prob0Min.getPlayer2Strategy();
-            storm::dd::Bdd<Type> maxPlayer1Strategy = qualitativeResult.prob1Max.getPlayer1Strategy();
-            storm::dd::Bdd<Type> maxPlayer2Strategy = qualitativeResult.prob1Max.getPlayer2Strategy();
+        PivotStateResult<Type> computePivotStates(storm::abstraction::MenuGame<Type, ValueType> const& game, storm::dd::Bdd<Type> const& transitionMatrixBdd, storm::dd::Bdd<Type> const& minPlayer1Strategy, storm::dd::Bdd<Type> const& minPlayer2Strategy, storm::dd::Bdd<Type> const& maxPlayer1Strategy, storm::dd::Bdd<Type> const& maxPlayer2Strategy) {
             
-            // Redirect all player 1 choices of the min strategy to that of the max strategy if this leads to a player 2
-            // state that is also a prob 0 state.
-            minPlayer1Strategy = (maxPlayer1Strategy && qualitativeResult.prob0Min.getPlayer2States()).existsAbstract(game.getPlayer1Variables()).ite(maxPlayer1Strategy, minPlayer1Strategy);
+            PivotStateResult<Type> result;
             
             // Build the fragment of transitions that is reachable by either the min or the max strategies.
-            storm::dd::Bdd<Type> reachableTransitions = transitionMatrixBdd && (minPlayer1Strategy || maxPlayer1Strategy) && minPlayer2Strategy && maxPlayer2Strategy;
-            reachableTransitions = reachableTransitions.existsAbstract(game.getNondeterminismVariables());
-
-            storm::dd::Bdd<Type> reachableTransitionsMin = (transitionMatrixBdd && minPlayer1Strategy && minPlayer2Strategy).existsAbstract(game.getNondeterminismVariables());
-            storm::dd::Bdd<Type> reachableTransitionsMax = (transitionMatrixBdd && maxPlayer1Strategy && maxPlayer2Strategy).existsAbstract(game.getNondeterminismVariables());
+            result.reachableTransitionsMin = (transitionMatrixBdd && minPlayer1Strategy && minPlayer2Strategy).existsAbstract(game.getNondeterminismVariables());
+            result.reachableTransitionsMax = (transitionMatrixBdd && maxPlayer1Strategy && maxPlayer2Strategy).existsAbstract(game.getNondeterminismVariables());
             
             // Start with all reachable states as potential pivot states.
-            storm::dd::Bdd<Type> pivotStates = storm::utility::dd::computeReachableStates(game.getInitialStates(), reachableTransitionsMin, game.getRowVariables(), game.getColumnVariables()) ||
-                                               storm::utility::dd::computeReachableStates(game.getInitialStates(), reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables());
-
-            //storm::dd::Bdd<Type> pivotStates = storm::utility::dd::computeReachableStates(game.getInitialStates(), reachableTransitions, game.getRowVariables(), game.getColumnVariables());
-
+            result.pivotStates = storm::utility::dd::computeReachableStates(game.getInitialStates(), result.reachableTransitionsMin, game.getRowVariables(), game.getColumnVariables()) ||
+                                 storm::utility::dd::computeReachableStates(game.getInitialStates(), result.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables());
+            
             // Then constrain these states by the requirement that for either the lower or upper player 1 choice the player 2 choices must be different and
             // that the difference is not because of a missing strategy in either case.
             
@@ -170,22 +154,13 @@ namespace storm {
             constraint &= minPlayer2Strategy.exclusiveOr(maxPlayer2Strategy);
             
             // Then restrict the pivot states by requiring existing and different player 2 choices.
-            // pivotStates &= ((minPlayer1Strategy || maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables());
-            pivotStates &= ((minPlayer1Strategy && maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables());
-            
-            // We can only refine in case we have a reachable player 1 state with a player 2 successor (under either
-            // player 1's min or max strategy) such that from this player 2 state, both prob0 min and prob1 max define
-            // strategies and they differ. Hence, it is possible that we arrive at a point where no suitable pivot state
-            // is found. In this case, we abort the qualitative refinement here.
-            if (pivotStates.isZero()) {
-                return false;
-            }
-            
-            STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to proceed without pivot state candidates.");
-            
-            // Now that we have the pivot state candidates, we need to pick one.
-            storm::dd::Bdd<Type> pivotState = pickPivotState<Type, ValueType>(game.getInitialStates(), reachableTransitionsMin, reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStates);
+            result.pivotStates &= ((minPlayer1Strategy && maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables());
             
+            return result;
+        }
+        
+        template<storm::dd::DdType Type, typename ValueType>
+        storm::expressions::Expression MenuGameRefiner<Type, ValueType>::derivePredicateFromPivotState(storm::abstraction::MenuGame<Type, ValueType> const& game, storm::dd::Bdd<Type> const& pivotState, storm::dd::Bdd<Type> const& minPlayer1Strategy, storm::dd::Bdd<Type> const& minPlayer2Strategy, storm::dd::Bdd<Type> const& maxPlayer1Strategy, storm::dd::Bdd<Type> const& maxPlayer2Strategy) const {
             // Compute the lower and the upper choice for the pivot state.
             std::set<storm::expressions::Variable> variablesToAbstract = game.getNondeterminismVariables();
             variablesToAbstract.insert(game.getRowVariables().begin(), game.getRowVariables().end());
@@ -198,10 +173,10 @@ namespace storm {
                 STORM_LOG_TRACE("Refining based on lower choice.");
                 auto refinementStart = std::chrono::high_resolution_clock::now();
                 
-                this->refine(pivotState, (pivotState && minPlayer1Strategy).existsAbstract(game.getRowVariables()), lowerChoice1, lowerChoice2);
+                storm::expressions::Expression newPredicate = derivePredicateFromDifferingChoices(pivotState, (pivotState && minPlayer1Strategy).existsAbstract(game.getRowVariables()), lowerChoice1, lowerChoice2);
                 auto refinementEnd = std::chrono::high_resolution_clock::now();
                 STORM_LOG_TRACE("Refinement completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(refinementEnd - refinementStart).count() << "ms.");
-                return true;
+                return newPredicate;
             } else {
                 storm::dd::Bdd<Type> upperChoice = pivotState && game.getExtendedTransitionMatrix().toBdd() && maxPlayer1Strategy;
                 storm::dd::Bdd<Type> upperChoice1 = (upperChoice && minPlayer2Strategy).existsAbstract(variablesToAbstract);
@@ -211,15 +186,48 @@ namespace storm {
                 if (upperChoicesDifferent) {
                     STORM_LOG_TRACE("Refining based on upper choice.");
                     auto refinementStart = std::chrono::high_resolution_clock::now();
-                    this->refine(pivotState, (pivotState && maxPlayer1Strategy).existsAbstract(game.getRowVariables()), upperChoice1, upperChoice2);
+                    storm::expressions::Expression newPredicate = derivePredicateFromDifferingChoices(pivotState, (pivotState && maxPlayer1Strategy).existsAbstract(game.getRowVariables()), upperChoice1, upperChoice2);
                     auto refinementEnd = std::chrono::high_resolution_clock::now();
                     STORM_LOG_TRACE("Refinement completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(refinementEnd - refinementStart).count() << "ms.");
-                    return true;
+                    return newPredicate;
                 } else {
                     STORM_LOG_ASSERT(false, "Did not find choices from which to derive predicates.");
                 }
             }
-            return false;
+        }
+        
+        template<storm::dd::DdType Type, typename ValueType>
+        bool MenuGameRefiner<Type, ValueType>::refine(storm::abstraction::MenuGame<Type, ValueType> const& game, storm::dd::Bdd<Type> const& transitionMatrixBdd, QualitativeResultMinMax<Type> const& qualitativeResult) const {
+            STORM_LOG_TRACE("Trying refinement after qualitative check.");
+            // Get all relevant strategies.
+            storm::dd::Bdd<Type> minPlayer1Strategy = qualitativeResult.prob0Min.getPlayer1Strategy();
+            storm::dd::Bdd<Type> minPlayer2Strategy = qualitativeResult.prob0Min.getPlayer2Strategy();
+            storm::dd::Bdd<Type> maxPlayer1Strategy = qualitativeResult.prob1Max.getPlayer1Strategy();
+            storm::dd::Bdd<Type> maxPlayer2Strategy = qualitativeResult.prob1Max.getPlayer2Strategy();
+            
+            // Redirect all player 1 choices of the min strategy to that of the max strategy if this leads to a player 2
+            // state that is also a prob 0 state.
+            minPlayer1Strategy = (maxPlayer1Strategy && qualitativeResult.prob0Min.getPlayer2States()).existsAbstract(game.getPlayer1Variables()).ite(maxPlayer1Strategy, minPlayer1Strategy);
+            
+            // Compute all reached pivot states.
+            PivotStateResult<Type> pivotStateResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy);
+            
+            // We can only refine in case we have a reachable player 1 state with a player 2 successor (under either
+            // player 1's min or max strategy) such that from this player 2 state, both prob0 min and prob1 max define
+            // strategies and they differ. Hence, it is possible that we arrive at a point where no suitable pivot state
+            // is found. In this case, we abort the qualitative refinement here.
+            if (pivotStateResult.pivotStates.isZero()) {
+                return false;
+            }
+            STORM_LOG_ASSERT(!pivotStateResult.pivotStates.isZero(), "Unable to proceed without pivot state candidates.");
+            
+            // Now that we have the pivot state candidates, we need to pick one.
+            storm::dd::Bdd<Type> pivotState = pickPivotStateWithMinimalDistance<Type, ValueType>(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates);
+            
+            // Derive predicate based on the selected pivot state.
+            storm::expressions::Expression newPredicate = derivePredicateFromPivotState(game, pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy);
+            performRefinement({newPredicate});
+            return true;
         }
         
         template<storm::dd::DdType Type, typename ValueType>
@@ -231,75 +239,21 @@ namespace storm {
             storm::dd::Bdd<Type> maxPlayer1Strategy = quantitativeResult.max.player1Strategy;
             storm::dd::Bdd<Type> maxPlayer2Strategy = quantitativeResult.max.player2Strategy;
             
-            // TODO: fix min strategies to take the max strategies if possible.
-            
-            // Build the fragment of transitions that is reachable by both the min and the max strategies.
-            storm::dd::Bdd<Type> reachableTransitions = transitionMatrixBdd && (minPlayer1Strategy || maxPlayer1Strategy) && minPlayer2Strategy && maxPlayer2Strategy;
-            reachableTransitions = reachableTransitions.existsAbstract(game.getNondeterminismVariables());
-            
-            storm::dd::Bdd<Type> reachableTransitionsMin = (transitionMatrixBdd && minPlayer1Strategy && minPlayer2Strategy).existsAbstract(game.getNondeterminismVariables());
-            storm::dd::Bdd<Type> reachableTransitionsMax = (transitionMatrixBdd && maxPlayer1Strategy && maxPlayer2Strategy).existsAbstract(game.getNondeterminismVariables());
-            
-            // Start with all reachable states as potential pivot states.
-            // storm::dd::Bdd<Type> pivotStates = storm::utility::dd::computeReachableStates(game.getInitialStates(), reachableTransitions, game.getRowVariables(), game.getColumnVariables());
-            storm::dd::Bdd<Type> pivotStates = storm::utility::dd::computeReachableStates(game.getInitialStates(), reachableTransitionsMin, game.getRowVariables(), game.getColumnVariables()) ||
-            storm::utility::dd::computeReachableStates(game.getInitialStates(), reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables());
+            // Compute all reached pivot states.
+            PivotStateResult<Type> pivotStateResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy);
 
-            
+            // TODO: required?
             // Require the pivot state to be a state with a lower bound strictly smaller than the upper bound.
-            pivotStates &= quantitativeResult.min.values.less(quantitativeResult.max.values);
-            
-            STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to refine without pivot state candidates.");
-            
-            // Then constrain these states by the requirement that for either the lower or upper player 1 choice the player 2 choices must be different and
-            // that the difference is not because of a missing strategy in either case.
-            
-            // Start with constructing the player 2 states that have a (min) and a (max) strategy.
-            // TODO: necessary?
-            storm::dd::Bdd<Type> constraint = minPlayer2Strategy.existsAbstract(game.getPlayer2Variables()) && maxPlayer2Strategy.existsAbstract(game.getPlayer2Variables());
-            
-            // Now construct all player 2 choices that actually exist and differ in the min and max case.
-            constraint &= minPlayer2Strategy.exclusiveOr(maxPlayer2Strategy);
-            
-            // Then restrict the pivot states by requiring existing and different player 2 choices.
-            // pivotStates &= ((minPlayer1Strategy || maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables());
-            pivotStates &= ((minPlayer1Strategy && maxPlayer1Strategy) && constraint).existsAbstract(game.getNondeterminismVariables());
-            
-            STORM_LOG_ASSERT(!pivotStates.isZero(), "Unable to refine without pivot state candidates.");
+            pivotStateResult.pivotStates &= quantitativeResult.min.values.less(quantitativeResult.max.values);
             
+            STORM_LOG_ASSERT(!pivotStateResult.pivotStates.isZero(), "Unable to refine without pivot state candidates.");
+
             // Now that we have the pivot state candidates, we need to pick one.
-            storm::dd::Bdd<Type> pivotState = pickPivotState<Type, ValueType>(game.getInitialStates(), reachableTransitionsMin, reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStates, quantitativeResult);
-            
-            // Compute the lower and the upper choice for the pivot state.
-            std::set<storm::expressions::Variable> variablesToAbstract = game.getNondeterminismVariables();
-            variablesToAbstract.insert(game.getRowVariables().begin(), game.getRowVariables().end());
-            storm::dd::Bdd<Type> lowerChoice = pivotState && game.getExtendedTransitionMatrix().toBdd() && minPlayer1Strategy;
-            storm::dd::Bdd<Type> lowerChoice1 = (lowerChoice && minPlayer2Strategy).existsAbstract(variablesToAbstract);
-            storm::dd::Bdd<Type> lowerChoice2 = (lowerChoice && maxPlayer2Strategy).existsAbstract(variablesToAbstract);
-            
-            bool lowerChoicesDifferent = !lowerChoice1.exclusiveOr(lowerChoice2).isZero();
-            if (lowerChoicesDifferent) {
-                STORM_LOG_TRACE("Refining based on lower choice.");
-                auto refinementStart = std::chrono::high_resolution_clock::now();
-                this->refine(pivotState, (pivotState && minPlayer1Strategy).existsAbstract(game.getRowVariables()), lowerChoice1, lowerChoice2);
-                auto refinementEnd = std::chrono::high_resolution_clock::now();
-                STORM_LOG_TRACE("Refinement completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(refinementEnd - refinementStart).count() << "ms.");
-            } else {
-                storm::dd::Bdd<Type> upperChoice = pivotState && game.getExtendedTransitionMatrix().toBdd() && maxPlayer1Strategy;
-                storm::dd::Bdd<Type> upperChoice1 = (upperChoice && minPlayer2Strategy).existsAbstract(variablesToAbstract);
-                storm::dd::Bdd<Type> upperChoice2 = (upperChoice && maxPlayer2Strategy).existsAbstract(variablesToAbstract);
-                
-                bool upperChoicesDifferent = !upperChoice1.exclusiveOr(upperChoice2).isZero();
-                if (upperChoicesDifferent) {
-                    STORM_LOG_TRACE("Refining based on upper choice.");
-                    auto refinementStart = std::chrono::high_resolution_clock::now();
-                    this->refine(pivotState, (pivotState && maxPlayer1Strategy).existsAbstract(game.getRowVariables()), upperChoice1, upperChoice2);
-                    auto refinementEnd = std::chrono::high_resolution_clock::now();
-                    STORM_LOG_TRACE("Refinement completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(refinementEnd - refinementStart).count() << "ms.");
-                } else {
-                    STORM_LOG_ASSERT(false, "Did not find choices from which to derive predicates.");
-                }
-            }
+            storm::dd::Bdd<Type> pivotState = pickPivotStateWithMinimalDistance<Type, ValueType>(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates);
+
+            // Derive predicate based on the selected pivot state.
+            storm::expressions::Expression newPredicate = derivePredicateFromPivotState(game, pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy);
+            performRefinement({newPredicate});
             return true;
         }
         
diff --git a/src/storm/abstraction/MenuGameRefiner.h b/src/storm/abstraction/MenuGameRefiner.h
index 5e866f860..7394c14b3 100644
--- a/src/storm/abstraction/MenuGameRefiner.h
+++ b/src/storm/abstraction/MenuGameRefiner.h
@@ -52,8 +52,8 @@ namespace storm {
             bool refine(storm::abstraction::MenuGame<Type, ValueType> const& game, storm::dd::Bdd<Type> const& transitionMatrixBdd, QuantitativeResultMinMax<Type, ValueType> const& quantitativeResult) const;
             
         private:
-            void refine(storm::dd::Bdd<Type> const& pivotState, storm::dd::Bdd<Type> const& player1Choice, storm::dd::Bdd<Type> const& lowerChoice, storm::dd::Bdd<Type> const& upperChoice) const;
-            
+            storm::expressions::Expression derivePredicateFromDifferingChoices(storm::dd::Bdd<Type> const& pivotState, storm::dd::Bdd<Type> const& player1Choice, storm::dd::Bdd<Type> const& lowerChoice, storm::dd::Bdd<Type> const& upperChoice) const;
+            storm::expressions::Expression derivePredicateFromPivotState(storm::abstraction::MenuGame<Type, ValueType> const& game, storm::dd::Bdd<Type> const& pivotState, storm::dd::Bdd<Type> const& minPlayer1Strategy, storm::dd::Bdd<Type> const& minPlayer2Strategy, storm::dd::Bdd<Type> const& maxPlayer1Strategy, storm::dd::Bdd<Type> const& maxPlayer2Strategy) const;
             /*!
              * Takes the given predicates, preprocesses them and then refines the abstractor.
              */