From 2cc5b6e0804d919bad3c4d3cf9261ad667b74b2c Mon Sep 17 00:00:00 2001
From: dehnert <dehnert@cs.rwth-aachen.de>
Date: Fri, 4 Oct 2013 19:48:43 +0200
Subject: [PATCH] Added Z3ExpressionAdapter to translate IR expressions to the
 Z3 format. Improvements to label-/command set generators. Disabled MILP-call
 from main().

Former-commit-id: 7128ab44778272fb69995c84b8f5c16c38c003d1
---
 src/adapters/Z3ExpressionAdapter.h            | 206 ++++++++++++++++++
 .../MILPMinimalLabelSetGenerator.h            |   1 -
 .../SMTMinimalCommandSetGenerator.h           | 129 +++++++++--
 src/storm.cpp                                 |  14 +-
 4 files changed, 326 insertions(+), 24 deletions(-)
 create mode 100644 src/adapters/Z3ExpressionAdapter.h

diff --git a/src/adapters/Z3ExpressionAdapter.h b/src/adapters/Z3ExpressionAdapter.h
new file mode 100644
index 000000000..c9d64e9d1
--- /dev/null
+++ b/src/adapters/Z3ExpressionAdapter.h
@@ -0,0 +1,206 @@
+/*
+ * Z3ExpressionAdapter.h
+ *
+ *  Created on: 04.10.2013
+ *      Author: Christian Dehnert
+ */
+
+#ifndef STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_
+#define STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_
+
+#include <stack>
+
+#include "src/ir/expressions/ExpressionVisitor.h"
+
+namespace storm {
+    namespace adapters {
+
+        class Z3ExpressionAdapter : public storm::ir::expressions::ExpressionVisitor {
+        public:
+            /*!
+             * Creates a Z3ExpressionAdapter over the given Z3 context.
+             *
+             * @param context The Z3 context over which to build the expressions.
+             */
+            Z3ExpressionAdapter(z3::context const& context, std::map<std::string, z3::expr> const& variableToExpressionMap) : context(context), stack(), variableToExpressionMap(variableToExpressionMap) {
+                // Intentionally left empty.
+            }
+            
+            /*!
+             * Translates the given expression to an equivalent expression for Z3.
+             *
+             * @param expression The expression to translate.
+             * @return An equivalent expression for Z3.
+             */
+            z3::expr translateExpression(std::shared_ptr<storm::ir::expressions::BaseExpression> expression) {
+                expression->accept(this);
+                return stack.top();
+            }
+            
+            virtual void visit(BinaryBooleanFunctionExpression* expression) {
+                expression->getLeft()->accept(this);
+                expression->getRight()->accept(this);
+                
+                z3::expr rightResult = stack.top();
+                stack.pop();
+                z3::expr leftResult = stack.top();
+                stack.pop();
+
+                switch(expression->getFunctionType()) {
+                    case storm::ir::expressions::BinaryBooleanFunctionExpression::AND:
+                        stack.push(leftResult && rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryBooleanFunctionExpression::OR:
+                        stack.push(leftResult || rightResult);
+                        break;
+                    default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+                        << "Unknown boolean binary operator: '" << expression->getFunctionType() << "'.";
+                }
+
+            }
+            
+            virtual void visit(BinaryNumericalFunctionExpression* expression) {
+                expression->getLeft()->accept(this);
+                expression->getRight()->accept(this);
+                
+                z3::expr rightResult = stack.top();
+                stack.pop();
+                z3::expr leftResult = stack.top();
+                stack.pop();
+                
+                switch(expression->getFunctionType()) {
+                    case storm::ir::expressions::BinaryNumericalFunctionExpression::PLUS:
+                        stack.push(leftResult + rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryNumericalFunctionExpression::MINUS:
+                        stack.push(leftResult - rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryNumericalFunctionExpression::TIMES:
+                        stack.push(leftResult * rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryNumericalFunctionExpression::DIVIDE:
+                        stack.push(leftResult / rightResult);
+                        break;
+                    default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+                        << "Unknown boolean binary operator: '" << expression->getFunctionType() << "'.";
+                }    
+            }
+            
+            virtual void visit(BinaryRelationExpression* expression) {
+                expression->getLeft()->accept(this);
+                expression->getRight()->accept(this);
+                
+                z3::expr rightResult = stack.top();
+                stack.pop();
+                z3::expr leftResult = stack.top();
+                stack.pop();
+                
+                switch(expression->getRelationType()) {
+                    case storm::ir::expressions::BinaryRelationExpression::EQUAL:
+                        stack.push(leftResult == rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryRelationExpression::NOT_EQUAL:
+                        stack.push(leftResult != rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryRelationExpression::LESS:
+                        stack.push(leftResult < rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryRelationExpression::LESS_OR_EQUAL:
+                        stack.push(leftResult <= rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryRelationExpression::GREATER:
+                        stack.push(leftResult > rightResult);
+                        break;
+                    case storm::ir::expressions::BinaryRelationExpression::GREATER_OR_EQUAL:
+                        stack.push(leftResult >= rightResult);
+                        break;
+                    default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+                        << "Unknown boolean binary operator: '" << expression->getRelationType() << "'.";
+                }    
+            }
+            
+            virtual void visit(BooleanConstantExpression* expression) {
+                if (!expression->isDefined()) {
+                    throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+					<< ". Boolean constant '" << expression->getConstantName() << "' is undefined.";
+                }
+                
+                stack.push(context.bool_val(expression->getValue()));    
+            }
+            
+            virtual void visit(BooleanLiteralExpression* expression) {
+                stack.push(context.bool_val(expression->getValueAsBool(nullptr))));
+            }
+            
+            virtual void visit(DoubleConstantExpression* expression) {
+                if (!expression->isDefined()) {
+                    throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+					<< ". Double constant '" << expression->getConstantName() << "' is undefined.";
+                }
+                
+                // FIXME: convert double value to suitable format.
+                stack.push(context.real_val(expression->getValue()));
+            }
+            
+            virtual void visit(DoubleLiteralExpression* expression) {
+                // FIXME: convert double value to suitable format.
+                stack.push(context.real_val(expression->getValue()));
+            }
+            
+            virtual void visit(IntegerConstantExpression* expression) {
+                if (!expression->isDefined()) {
+                    throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+					<< ". Integer constant '" << expression->getConstantName() << "' is undefined.";
+                }
+                
+                stack.push(context.int_val(expression->getValue()));    
+            }
+            
+            virtual void visit(IntegerLiteralExpression* expression) {
+                stack.push(context.int_val(expression->getValue()));    
+            }
+            
+            virtual void visit(UnaryBooleanFunctionExpression* expression) {
+                expression->getChild()->accept(this);
+                
+                z3::expr childResult = stack.top();
+                stack.pop();
+                
+                switch (expression->getFunctionType()) {
+                    case storm::ir::expressions::UnaryBooleanFunctionExpression::NOT:
+                        stack.push(!childResult);
+                        break;
+                    default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+                        << "Unknown boolean unary operator: '" << expression->getFunctionType() << "'.";
+                }    
+            }
+            
+            virtual void visit(UnaryNumericalFunctionExpression* expression) {
+                expression->getChild()->accept(this);
+                
+                z3::expr childResult = stack.top();
+                stack.pop();
+                
+                switch(expression->getFunctionType()) {
+                    case storm::ir::expressions::UnaryNumericalFunctionExpression::MINUS:
+                        stack.push(0 - childResult);
+                        break;
+                    default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
+                        << "Unknown numerical unary operator: '" << expression->getFunctionType() << "'.";
+                }
+            }
+            
+            virtual void visit(VariableExpression* expression) {
+                stack.push(variableToExpressionMap.at(expression->getVariableName());
+            }
+            
+        private:
+            z3::context context;
+            std::stack<z3::expr> stack;
+            std::map<std::string, z3::expr> variableToExpressionMap
+        }
+
+    } // namespace adapters
+} // namespace storm
+        
+#endif /* STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ */
diff --git a/src/counterexamples/MILPMinimalLabelSetGenerator.h b/src/counterexamples/MILPMinimalLabelSetGenerator.h
index a3387b174..3392e16d8 100644
--- a/src/counterexamples/MILPMinimalLabelSetGenerator.h
+++ b/src/counterexamples/MILPMinimalLabelSetGenerator.h
@@ -926,7 +926,6 @@ namespace storm {
                     for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) {
                         if (state != *predecessorIt) {
                             predecessors.insert(*predecessorIt);
-
                         }
                     }
                     
diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h
index dd96d07cc..2d6d6381f 100644
--- a/src/counterexamples/SMTMinimalCommandSetGenerator.h
+++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h
@@ -31,6 +31,12 @@ namespace storm {
         class SMTMinimalCommandSetGenerator {
 #ifdef STORM_HAVE_Z3
         private:
+            struct RelevancyInformation {
+                storm::storage::BitVector relevantStates;
+                std::set<uint_fast64_t> relevantLabels;
+                std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
+            };
+            
             struct VariableInformation {
                 std::vector<z3::expr> labelVariables;
                 std::vector<z3::expr> auxiliaryVariables;
@@ -44,17 +50,19 @@ namespace storm {
              * @param labeledMdp The MDP to search for relevant labels.
              * @param phiStates A bit vector representing all states that satisfy phi.
              * @param psiStates A bit vector representing all states that satisfy psi.
-             * @return A set of relevant labels, where relevant is defined as above.
+             * @return A structure containing the relevant labels as well as states.
              */
-            static std::set<uint_fast64_t> getRelevantLabels(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) {
+            static RelevancyInformation determineRelevantStatesAndLabels(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) {
                 // Create result.
-                std::set<uint_fast64_t> relevantLabels;
+                RelevancyInformation relevancyInformation;
                 
                 // Compute all relevant states, i.e. states for which there exists a scheduler that has a non-zero
                 // probabilitiy of satisfying phi until psi.
                 storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
-                storm::storage::BitVector relevantStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates);
-                relevantStates &= ~psiStates;
+                relevancyInformation.relevantStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates);
+                relevancyInformation.relevantStates &= ~psiStates;
+
+                LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantStates.getNumberOfSetBits() << " relevant states.");
 
                 // Retrieve some references for convenient access.
                 storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
@@ -64,21 +72,29 @@ namespace storm {
                 // Now traverse all choices of all relevant states and check whether there is a successor target state.
                 // If so, the associated labels become relevant. Also, if a choice of relevant state has at least one
                 // relevant successor, the choice becomes relevant.
-                for (auto state : relevantStates) {
+                for (auto state : relevancyInformation.relevantStates) {
+                    relevancyInformation.relevantChoicesForRelevantStates.emplace(state, std::list<uint_fast64_t>());
+                    
                     for (uint_fast64_t row = nondeterministicChoiceIndices[state]; row < nondeterministicChoiceIndices[state + 1]; ++row) {
+                        bool currentChoiceRelevant = false;
+
                         for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) {
                             // If there is a relevant successor, we need to add the labels of the current choice.
-                            if (relevantStates.get(*successorIt) || psiStates.get(*successorIt)) {
+                            if (relevancyInformation.relevantStates.get(*successorIt) || psiStates.get(*successorIt)) {
                                 for (auto const& label : choiceLabeling[row]) {
-                                    relevantLabels.insert(label);
+                                    relevancyInformation.relevantLabels.insert(label);
+                                }
+                                if (!currentChoiceRelevant) {
+                                    currentChoiceRelevant = true;
+                                    relevancyInformation.relevantChoicesForRelevantStates[state].push_back(row);
                                 }
                             }
                         }
                     }
                 }
                 
-                LOG4CPLUS_DEBUG(logger, "Found " << relevantLabels.size() << " relevant labels.");
-                return relevantLabels;
+                LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantLabels.size() << " relevant labels.");
+                return relevancyInformation;
             }
             
             /*!
@@ -119,11 +135,12 @@ namespace storm {
              * Asserts the constraints that are initially known.
              *
              * @param program The program for which to build the constraints.
+             * @param labeledMdp The MDP that results from the given program.
              * @param context The Z3 context in which to build the expressions.
              * @param solver The solver in which to assert the constraints.
              * @param variableInformation A structure with information about the variables for the labels.
              */
-            static void assertInitialConstraints(storm::ir::Program const& program, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation) {
+            static void assertInitialConstraints(storm::ir::Program const& program, storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& psiStates, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation) {
                 // Assert that at least one of the labels must be taken.
                 z3::expr formula = variableInformation.labelVariables.at(0);
                 for (uint_fast64_t index = 1; index < variableInformation.labelVariables.size(); ++index) {
@@ -134,8 +151,84 @@ namespace storm {
                 for (uint_fast64_t index = 0; index < variableInformation.labelVariables.size(); ++index) {
                     solver.add(!variableInformation.labelVariables[index] || variableInformation.auxiliaryVariables[index]);
                 }
+                
+                std::vector<std::set<uint_fast64_t>> const& choiceLabeling = labeledMdp.getChoiceLabeling();
+                storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
+
+                // Assert that at least one of the labels of one of the relevant initial states is taken.
+                std::vector<z3::expr> expressionVector;
+                bool firstAssignment = true;
+                for (auto state : labeledMdp.getInitialStates()) {
+                    if (relevancyInformation.relevantStates.get(state)) {
+                        for (auto const& choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
+                            for (auto const& label : choiceLabeling[choice]) {
+                                z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
+                                if (firstAssignment) {
+                                    expressionVector.push_back(labelExpression);
+                                    firstAssignment = false;
+                                } else {
+                                    expressionVector.back() = expressionVector.back() && labelExpression;
+                                }
+                            }
+                        }
+                    }
+                }
+                assertDisjunction(context, solver, expressionVector);
+                
+                // Assert that at least one of the labels that are selected can reach a target state in one step.
+                storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
+
+                // Compute the set of predecessors of target states.
+                std::unordered_set<uint_fast64_t> predecessors;
+                for (auto state : psiStates) {
+                    for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) {
+                        if (state != *predecessorIt) {
+                            predecessors.insert(*predecessorIt);
+                        }
+                    }
+                }
+
+                expressionVector.clear();
+                firstAssignment = true;
+                for (auto state : predecessors) {
+                    for (auto choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
+                        for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(choice); successorIt != transitionMatrix.constColumnIteratorEnd(choice); ++successorIt) {
+                            if (psiStates.get(*successorIt)) {
+                                for (auto const& label : choiceLabeling[choice]) {
+                                    z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
+                                    if (firstAssignment) {
+                                        expressionVector.push_back(labelExpression);
+                                        firstAssignment = false;
+                                    } else {
+                                        expressionVector.back() = expressionVector.back() && labelExpression;
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+                assertDisjunction(context, solver, expressionVector);
             }
 
+            /*!
+             * Asserts that the disjunction of the given formulae holds.
+             *
+             * @param context The Z3 context in which to build the expressions.
+             * @param solver The solver to use for the satisfiability evaluation.
+             * @param formulaVector A vector of expressions that shall form the disjunction.
+             */
+            static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) {
+                z3::expr disjunction(context);
+                for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) {
+                    if (i == 0) {
+                        disjunction = formulaVector[i];
+                    } else {
+                        disjunction = disjunction || formulaVector[i];
+                    }
+                }
+                solver.add(disjunction);
+            }
+            
             /*!
              * Asserts that at most one of the blocking variables may be true at any time.
              *
@@ -288,20 +381,20 @@ namespace storm {
                 
                 // (1) FIXME: check whether its possible to exceed the threshold if checkThresholdFeasible is set.
 
-                // (2) Identify all commands that are relevant, because only these need to be considered later.
-                std::set<uint_fast64_t> relevantCommands = getRelevantLabels(labeledMdp, phiStates, psiStates);
+                // (2) Identify all states and commands that are relevant, because only these need to be considered later.
+                RelevancyInformation relevancyInformation = determineRelevantStatesAndLabels(labeledMdp, phiStates, psiStates);
                 
                 // (3) Create context for solver.
                 z3::context context;
                 
                 // (4) Create the variables for the relevant commands.
-                VariableInformation variableInformation = createExpressionsForRelevantLabels(context, relevantCommands);
+                VariableInformation variableInformation = createExpressionsForRelevantLabels(context, relevancyInformation.relevantLabels);
                 
                 // (5) After all variables have been created, create a solver for that context.
                 z3::solver solver(context);
 
                 // (5) Build the initial constraint system.
-                assertInitialConstraints(program, context, solver, variableInformation);
+                assertInitialConstraints(program, labeledMdp, psiStates, context, solver, variableInformation, relevancyInformation);
                 
                 // (6) Find the smallest set of commands that satisfies all constraints. If the probability of
                 // satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned.
@@ -323,13 +416,16 @@ namespace storm {
                 std::set<uint_fast64_t> commandSet;
                 double maximalReachabilityProbability = 0;
                 bool done = false;
+                uint_fast64_t iterations = 0;
                 do {
                     commandSet = findSmallestCommandSet(context, solver, variableInformation, softConstraints, nextFreeVariableIndex);
                     
                     // Restrict the given MDP to the current set of labels and compute the reachability probability.
                     storm::models::Mdp<T> subMdp = labeledMdp.restrictChoiceLabels(commandSet);
                     storm::modelchecker::prctl::SparseMdpPrctlModelChecker<T> modelchecker(subMdp, new storm::solver::GmmxxNondeterministicLinearEquationSolver<T>());
+                    LOG4CPLUS_DEBUG(logger, "Invoking model checker.");
                     std::vector<T> result = modelchecker.checkUntil(false, phiStates, psiStates, false, nullptr);
+                    LOG4CPLUS_DEBUG(logger, "Computed model checking results.");
                     
                     // Now determine the maximal reachability probability by checking all initial states.
                     for (auto state : labeledMdp.getInitialStates()) {
@@ -342,7 +438,8 @@ namespace storm {
                     } else {
                         done = true;
                     }
-                    std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl;
+                    std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands in iteration " << iterations << "." << std::endl;
+                    ++iterations;
                 } while (!done);
                 
                 std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl;
diff --git a/src/storm.cpp b/src/storm.cpp
index 9543930da..ad9985d60 100644
--- a/src/storm.cpp
+++ b/src/storm.cpp
@@ -338,13 +338,13 @@ int main(const int argc, const char* argv[]) {
 			model->printModelInformationToStream(std::cout);
 
             // Enable the following lines to test the MinimalLabelSetGenerator.
-            if (model->getType() == storm::models::MDP) {
-                std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
-                storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
-                storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
-                storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
-                storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
-            }
+//            if (model->getType() == storm::models::MDP) {
+//                std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
+//                storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
+//                storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
+//                storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
+//                storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
+//            }
             
             // Enable the following lines to test the SMTMinimalCommandSetGenerator.
             if (model->getType() == storm::models::MDP) {