diff --git a/src/adapters/Z3ExpressionAdapter.h b/src/adapters/Z3ExpressionAdapter.h new file mode 100644 index 000000000..c9d64e9d1 --- /dev/null +++ b/src/adapters/Z3ExpressionAdapter.h @@ -0,0 +1,206 @@ +/* + * Z3ExpressionAdapter.h + * + * Created on: 04.10.2013 + * Author: Christian Dehnert + */ + +#ifndef STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ +#define STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ + +#include + +#include "src/ir/expressions/ExpressionVisitor.h" + +namespace storm { + namespace adapters { + + class Z3ExpressionAdapter : public storm::ir::expressions::ExpressionVisitor { + public: + /*! + * Creates a Z3ExpressionAdapter over the given Z3 context. + * + * @param context The Z3 context over which to build the expressions. + */ + Z3ExpressionAdapter(z3::context const& context, std::map const& variableToExpressionMap) : context(context), stack(), variableToExpressionMap(variableToExpressionMap) { + // Intentionally left empty. + } + + /*! + * Translates the given expression to an equivalent expression for Z3. + * + * @param expression The expression to translate. + * @return An equivalent expression for Z3. + */ + z3::expr translateExpression(std::shared_ptr expression) { + expression->accept(this); + return stack.top(); + } + + virtual void visit(BinaryBooleanFunctionExpression* expression) { + expression->getLeft()->accept(this); + expression->getRight()->accept(this); + + z3::expr rightResult = stack.top(); + stack.pop(); + z3::expr leftResult = stack.top(); + stack.pop(); + + switch(expression->getFunctionType()) { + case storm::ir::expressions::BinaryBooleanFunctionExpression::AND: + stack.push(leftResult && rightResult); + break; + case storm::ir::expressions::BinaryBooleanFunctionExpression::OR: + stack.push(leftResult || rightResult); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean binary operator: '" << expression->getFunctionType() << "'."; + } + + } + + virtual void visit(BinaryNumericalFunctionExpression* expression) { + expression->getLeft()->accept(this); + expression->getRight()->accept(this); + + z3::expr rightResult = stack.top(); + stack.pop(); + z3::expr leftResult = stack.top(); + stack.pop(); + + switch(expression->getFunctionType()) { + case storm::ir::expressions::BinaryNumericalFunctionExpression::PLUS: + stack.push(leftResult + rightResult); + break; + case storm::ir::expressions::BinaryNumericalFunctionExpression::MINUS: + stack.push(leftResult - rightResult); + break; + case storm::ir::expressions::BinaryNumericalFunctionExpression::TIMES: + stack.push(leftResult * rightResult); + break; + case storm::ir::expressions::BinaryNumericalFunctionExpression::DIVIDE: + stack.push(leftResult / rightResult); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean binary operator: '" << expression->getFunctionType() << "'."; + } + } + + virtual void visit(BinaryRelationExpression* expression) { + expression->getLeft()->accept(this); + expression->getRight()->accept(this); + + z3::expr rightResult = stack.top(); + stack.pop(); + z3::expr leftResult = stack.top(); + stack.pop(); + + switch(expression->getRelationType()) { + case storm::ir::expressions::BinaryRelationExpression::EQUAL: + stack.push(leftResult == rightResult); + break; + case storm::ir::expressions::BinaryRelationExpression::NOT_EQUAL: + stack.push(leftResult != rightResult); + break; + case storm::ir::expressions::BinaryRelationExpression::LESS: + stack.push(leftResult < rightResult); + break; + case storm::ir::expressions::BinaryRelationExpression::LESS_OR_EQUAL: + stack.push(leftResult <= rightResult); + break; + case storm::ir::expressions::BinaryRelationExpression::GREATER: + stack.push(leftResult > rightResult); + break; + case storm::ir::expressions::BinaryRelationExpression::GREATER_OR_EQUAL: + stack.push(leftResult >= rightResult); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean binary operator: '" << expression->getRelationType() << "'."; + } + } + + virtual void visit(BooleanConstantExpression* expression) { + if (!expression->isDefined()) { + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << ". Boolean constant '" << expression->getConstantName() << "' is undefined."; + } + + stack.push(context.bool_val(expression->getValue())); + } + + virtual void visit(BooleanLiteralExpression* expression) { + stack.push(context.bool_val(expression->getValueAsBool(nullptr)))); + } + + virtual void visit(DoubleConstantExpression* expression) { + if (!expression->isDefined()) { + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << ". Double constant '" << expression->getConstantName() << "' is undefined."; + } + + // FIXME: convert double value to suitable format. + stack.push(context.real_val(expression->getValue())); + } + + virtual void visit(DoubleLiteralExpression* expression) { + // FIXME: convert double value to suitable format. + stack.push(context.real_val(expression->getValue())); + } + + virtual void visit(IntegerConstantExpression* expression) { + if (!expression->isDefined()) { + throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << ". Integer constant '" << expression->getConstantName() << "' is undefined."; + } + + stack.push(context.int_val(expression->getValue())); + } + + virtual void visit(IntegerLiteralExpression* expression) { + stack.push(context.int_val(expression->getValue())); + } + + virtual void visit(UnaryBooleanFunctionExpression* expression) { + expression->getChild()->accept(this); + + z3::expr childResult = stack.top(); + stack.pop(); + + switch (expression->getFunctionType()) { + case storm::ir::expressions::UnaryBooleanFunctionExpression::NOT: + stack.push(!childResult); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown boolean unary operator: '" << expression->getFunctionType() << "'."; + } + } + + virtual void visit(UnaryNumericalFunctionExpression* expression) { + expression->getChild()->accept(this); + + z3::expr childResult = stack.top(); + stack.pop(); + + switch(expression->getFunctionType()) { + case storm::ir::expressions::UnaryNumericalFunctionExpression::MINUS: + stack.push(0 - childResult); + break; + default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " + << "Unknown numerical unary operator: '" << expression->getFunctionType() << "'."; + } + } + + virtual void visit(VariableExpression* expression) { + stack.push(variableToExpressionMap.at(expression->getVariableName()); + } + + private: + z3::context context; + std::stack stack; + std::map variableToExpressionMap + } + + } // namespace adapters +} // namespace storm + +#endif /* STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ */ diff --git a/src/counterexamples/MILPMinimalLabelSetGenerator.h b/src/counterexamples/MILPMinimalLabelSetGenerator.h index a3387b174..3392e16d8 100644 --- a/src/counterexamples/MILPMinimalLabelSetGenerator.h +++ b/src/counterexamples/MILPMinimalLabelSetGenerator.h @@ -926,7 +926,6 @@ namespace storm { for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) { if (state != *predecessorIt) { predecessors.insert(*predecessorIt); - } } diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index dd96d07cc..2d6d6381f 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -31,6 +31,12 @@ namespace storm { class SMTMinimalCommandSetGenerator { #ifdef STORM_HAVE_Z3 private: + struct RelevancyInformation { + storm::storage::BitVector relevantStates; + std::set relevantLabels; + std::unordered_map> relevantChoicesForRelevantStates; + }; + struct VariableInformation { std::vector labelVariables; std::vector auxiliaryVariables; @@ -44,17 +50,19 @@ namespace storm { * @param labeledMdp The MDP to search for relevant labels. * @param phiStates A bit vector representing all states that satisfy phi. * @param psiStates A bit vector representing all states that satisfy psi. - * @return A set of relevant labels, where relevant is defined as above. + * @return A structure containing the relevant labels as well as states. */ - static std::set getRelevantLabels(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { + static RelevancyInformation determineRelevantStatesAndLabels(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { // Create result. - std::set relevantLabels; + RelevancyInformation relevancyInformation; // Compute all relevant states, i.e. states for which there exists a scheduler that has a non-zero // probabilitiy of satisfying phi until psi. storm::storage::SparseMatrix backwardTransitions = labeledMdp.getBackwardTransitions(); - storm::storage::BitVector relevantStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates); - relevantStates &= ~psiStates; + relevancyInformation.relevantStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates); + relevancyInformation.relevantStates &= ~psiStates; + + LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantStates.getNumberOfSetBits() << " relevant states."); // Retrieve some references for convenient access. storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); @@ -64,21 +72,29 @@ namespace storm { // Now traverse all choices of all relevant states and check whether there is a successor target state. // If so, the associated labels become relevant. Also, if a choice of relevant state has at least one // relevant successor, the choice becomes relevant. - for (auto state : relevantStates) { + for (auto state : relevancyInformation.relevantStates) { + relevancyInformation.relevantChoicesForRelevantStates.emplace(state, std::list()); + for (uint_fast64_t row = nondeterministicChoiceIndices[state]; row < nondeterministicChoiceIndices[state + 1]; ++row) { + bool currentChoiceRelevant = false; + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) { // If there is a relevant successor, we need to add the labels of the current choice. - if (relevantStates.get(*successorIt) || psiStates.get(*successorIt)) { + if (relevancyInformation.relevantStates.get(*successorIt) || psiStates.get(*successorIt)) { for (auto const& label : choiceLabeling[row]) { - relevantLabels.insert(label); + relevancyInformation.relevantLabels.insert(label); + } + if (!currentChoiceRelevant) { + currentChoiceRelevant = true; + relevancyInformation.relevantChoicesForRelevantStates[state].push_back(row); } } } } } - LOG4CPLUS_DEBUG(logger, "Found " << relevantLabels.size() << " relevant labels."); - return relevantLabels; + LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantLabels.size() << " relevant labels."); + return relevancyInformation; } /*! @@ -119,11 +135,12 @@ namespace storm { * Asserts the constraints that are initially known. * * @param program The program for which to build the constraints. + * @param labeledMdp The MDP that results from the given program. * @param context The Z3 context in which to build the expressions. * @param solver The solver in which to assert the constraints. * @param variableInformation A structure with information about the variables for the labels. */ - static void assertInitialConstraints(storm::ir::Program const& program, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation) { + static void assertInitialConstraints(storm::ir::Program const& program, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation) { // Assert that at least one of the labels must be taken. z3::expr formula = variableInformation.labelVariables.at(0); for (uint_fast64_t index = 1; index < variableInformation.labelVariables.size(); ++index) { @@ -134,8 +151,84 @@ namespace storm { for (uint_fast64_t index = 0; index < variableInformation.labelVariables.size(); ++index) { solver.add(!variableInformation.labelVariables[index] || variableInformation.auxiliaryVariables[index]); } + + std::vector> const& choiceLabeling = labeledMdp.getChoiceLabeling(); + storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); + + // Assert that at least one of the labels of one of the relevant initial states is taken. + std::vector expressionVector; + bool firstAssignment = true; + for (auto state : labeledMdp.getInitialStates()) { + if (relevancyInformation.relevantStates.get(state)) { + for (auto const& choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) { + for (auto const& label : choiceLabeling[choice]) { + z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); + if (firstAssignment) { + expressionVector.push_back(labelExpression); + firstAssignment = false; + } else { + expressionVector.back() = expressionVector.back() && labelExpression; + } + } + } + } + } + assertDisjunction(context, solver, expressionVector); + + // Assert that at least one of the labels that are selected can reach a target state in one step. + storm::storage::SparseMatrix backwardTransitions = labeledMdp.getBackwardTransitions(); + + // Compute the set of predecessors of target states. + std::unordered_set predecessors; + for (auto state : psiStates) { + for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) { + if (state != *predecessorIt) { + predecessors.insert(*predecessorIt); + } + } + } + + expressionVector.clear(); + firstAssignment = true; + for (auto state : predecessors) { + for (auto choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) { + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(choice); successorIt != transitionMatrix.constColumnIteratorEnd(choice); ++successorIt) { + if (psiStates.get(*successorIt)) { + for (auto const& label : choiceLabeling[choice]) { + z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); + if (firstAssignment) { + expressionVector.push_back(labelExpression); + firstAssignment = false; + } else { + expressionVector.back() = expressionVector.back() && labelExpression; + } + } + } + } + } + } + assertDisjunction(context, solver, expressionVector); } + /*! + * Asserts that the disjunction of the given formulae holds. + * + * @param context The Z3 context in which to build the expressions. + * @param solver The solver to use for the satisfiability evaluation. + * @param formulaVector A vector of expressions that shall form the disjunction. + */ + static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector const& formulaVector) { + z3::expr disjunction(context); + for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) { + if (i == 0) { + disjunction = formulaVector[i]; + } else { + disjunction = disjunction || formulaVector[i]; + } + } + solver.add(disjunction); + } + /*! * Asserts that at most one of the blocking variables may be true at any time. * @@ -288,20 +381,20 @@ namespace storm { // (1) FIXME: check whether its possible to exceed the threshold if checkThresholdFeasible is set. - // (2) Identify all commands that are relevant, because only these need to be considered later. - std::set relevantCommands = getRelevantLabels(labeledMdp, phiStates, psiStates); + // (2) Identify all states and commands that are relevant, because only these need to be considered later. + RelevancyInformation relevancyInformation = determineRelevantStatesAndLabels(labeledMdp, phiStates, psiStates); // (3) Create context for solver. z3::context context; // (4) Create the variables for the relevant commands. - VariableInformation variableInformation = createExpressionsForRelevantLabels(context, relevantCommands); + VariableInformation variableInformation = createExpressionsForRelevantLabels(context, relevancyInformation.relevantLabels); // (5) After all variables have been created, create a solver for that context. z3::solver solver(context); // (5) Build the initial constraint system. - assertInitialConstraints(program, context, solver, variableInformation); + assertInitialConstraints(program, labeledMdp, psiStates, context, solver, variableInformation, relevancyInformation); // (6) Find the smallest set of commands that satisfies all constraints. If the probability of // satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned. @@ -323,13 +416,16 @@ namespace storm { std::set commandSet; double maximalReachabilityProbability = 0; bool done = false; + uint_fast64_t iterations = 0; do { commandSet = findSmallestCommandSet(context, solver, variableInformation, softConstraints, nextFreeVariableIndex); // Restrict the given MDP to the current set of labels and compute the reachability probability. storm::models::Mdp subMdp = labeledMdp.restrictChoiceLabels(commandSet); storm::modelchecker::prctl::SparseMdpPrctlModelChecker modelchecker(subMdp, new storm::solver::GmmxxNondeterministicLinearEquationSolver()); + LOG4CPLUS_DEBUG(logger, "Invoking model checker."); std::vector result = modelchecker.checkUntil(false, phiStates, psiStates, false, nullptr); + LOG4CPLUS_DEBUG(logger, "Computed model checking results."); // Now determine the maximal reachability probability by checking all initial states. for (auto state : labeledMdp.getInitialStates()) { @@ -342,7 +438,8 @@ namespace storm { } else { done = true; } - std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl; + std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands in iteration " << iterations << "." << std::endl; + ++iterations; } while (!done); std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl; diff --git a/src/storm.cpp b/src/storm.cpp index 9543930da..ad9985d60 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -338,13 +338,13 @@ int main(const int argc, const char* argv[]) { model->printModelInformationToStream(std::cout); // Enable the following lines to test the MinimalLabelSetGenerator. - if (model->getType() == storm::models::MDP) { - std::shared_ptr> labeledMdp = model->as>(); - storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); - storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); - storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; - storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); - } +// if (model->getType() == storm::models::MDP) { +// std::shared_ptr> labeledMdp = model->as>(); +// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); +// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); +// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; +// storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); +// } // Enable the following lines to test the SMTMinimalCommandSetGenerator. if (model->getType() == storm::models::MDP) {