Browse Source

Added Z3ExpressionAdapter to translate IR expressions to the Z3 format. Improvements to label-/command set generators. Disabled MILP-call from main().

Former-commit-id: 7128ab4477
tempestpy_adaptions
dehnert 11 years ago
parent
commit
2cc5b6e080
  1. 206
      src/adapters/Z3ExpressionAdapter.h
  2. 1
      src/counterexamples/MILPMinimalLabelSetGenerator.h
  3. 129
      src/counterexamples/SMTMinimalCommandSetGenerator.h
  4. 14
      src/storm.cpp

206
src/adapters/Z3ExpressionAdapter.h

@ -0,0 +1,206 @@
/*
* Z3ExpressionAdapter.h
*
* Created on: 04.10.2013
* Author: Christian Dehnert
*/
#ifndef STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_
#define STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_
#include <stack>
#include "src/ir/expressions/ExpressionVisitor.h"
namespace storm {
namespace adapters {
class Z3ExpressionAdapter : public storm::ir::expressions::ExpressionVisitor {
public:
/*!
* Creates a Z3ExpressionAdapter over the given Z3 context.
*
* @param context The Z3 context over which to build the expressions.
*/
Z3ExpressionAdapter(z3::context const& context, std::map<std::string, z3::expr> const& variableToExpressionMap) : context(context), stack(), variableToExpressionMap(variableToExpressionMap) {
// Intentionally left empty.
}
/*!
* Translates the given expression to an equivalent expression for Z3.
*
* @param expression The expression to translate.
* @return An equivalent expression for Z3.
*/
z3::expr translateExpression(std::shared_ptr<storm::ir::expressions::BaseExpression> expression) {
expression->accept(this);
return stack.top();
}
virtual void visit(BinaryBooleanFunctionExpression* expression) {
expression->getLeft()->accept(this);
expression->getRight()->accept(this);
z3::expr rightResult = stack.top();
stack.pop();
z3::expr leftResult = stack.top();
stack.pop();
switch(expression->getFunctionType()) {
case storm::ir::expressions::BinaryBooleanFunctionExpression::AND:
stack.push(leftResult && rightResult);
break;
case storm::ir::expressions::BinaryBooleanFunctionExpression::OR:
stack.push(leftResult || rightResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << expression->getFunctionType() << "'.";
}
}
virtual void visit(BinaryNumericalFunctionExpression* expression) {
expression->getLeft()->accept(this);
expression->getRight()->accept(this);
z3::expr rightResult = stack.top();
stack.pop();
z3::expr leftResult = stack.top();
stack.pop();
switch(expression->getFunctionType()) {
case storm::ir::expressions::BinaryNumericalFunctionExpression::PLUS:
stack.push(leftResult + rightResult);
break;
case storm::ir::expressions::BinaryNumericalFunctionExpression::MINUS:
stack.push(leftResult - rightResult);
break;
case storm::ir::expressions::BinaryNumericalFunctionExpression::TIMES:
stack.push(leftResult * rightResult);
break;
case storm::ir::expressions::BinaryNumericalFunctionExpression::DIVIDE:
stack.push(leftResult / rightResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << expression->getFunctionType() << "'.";
}
}
virtual void visit(BinaryRelationExpression* expression) {
expression->getLeft()->accept(this);
expression->getRight()->accept(this);
z3::expr rightResult = stack.top();
stack.pop();
z3::expr leftResult = stack.top();
stack.pop();
switch(expression->getRelationType()) {
case storm::ir::expressions::BinaryRelationExpression::EQUAL:
stack.push(leftResult == rightResult);
break;
case storm::ir::expressions::BinaryRelationExpression::NOT_EQUAL:
stack.push(leftResult != rightResult);
break;
case storm::ir::expressions::BinaryRelationExpression::LESS:
stack.push(leftResult < rightResult);
break;
case storm::ir::expressions::BinaryRelationExpression::LESS_OR_EQUAL:
stack.push(leftResult <= rightResult);
break;
case storm::ir::expressions::BinaryRelationExpression::GREATER:
stack.push(leftResult > rightResult);
break;
case storm::ir::expressions::BinaryRelationExpression::GREATER_OR_EQUAL:
stack.push(leftResult >= rightResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << expression->getRelationType() << "'.";
}
}
virtual void visit(BooleanConstantExpression* expression) {
if (!expression->isDefined()) {
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< ". Boolean constant '" << expression->getConstantName() << "' is undefined.";
}
stack.push(context.bool_val(expression->getValue()));
}
virtual void visit(BooleanLiteralExpression* expression) {
stack.push(context.bool_val(expression->getValueAsBool(nullptr))));
}
virtual void visit(DoubleConstantExpression* expression) {
if (!expression->isDefined()) {
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< ". Double constant '" << expression->getConstantName() << "' is undefined.";
}
// FIXME: convert double value to suitable format.
stack.push(context.real_val(expression->getValue()));
}
virtual void visit(DoubleLiteralExpression* expression) {
// FIXME: convert double value to suitable format.
stack.push(context.real_val(expression->getValue()));
}
virtual void visit(IntegerConstantExpression* expression) {
if (!expression->isDefined()) {
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< ". Integer constant '" << expression->getConstantName() << "' is undefined.";
}
stack.push(context.int_val(expression->getValue()));
}
virtual void visit(IntegerLiteralExpression* expression) {
stack.push(context.int_val(expression->getValue()));
}
virtual void visit(UnaryBooleanFunctionExpression* expression) {
expression->getChild()->accept(this);
z3::expr childResult = stack.top();
stack.pop();
switch (expression->getFunctionType()) {
case storm::ir::expressions::UnaryBooleanFunctionExpression::NOT:
stack.push(!childResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean unary operator: '" << expression->getFunctionType() << "'.";
}
}
virtual void visit(UnaryNumericalFunctionExpression* expression) {
expression->getChild()->accept(this);
z3::expr childResult = stack.top();
stack.pop();
switch(expression->getFunctionType()) {
case storm::ir::expressions::UnaryNumericalFunctionExpression::MINUS:
stack.push(0 - childResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown numerical unary operator: '" << expression->getFunctionType() << "'.";
}
}
virtual void visit(VariableExpression* expression) {
stack.push(variableToExpressionMap.at(expression->getVariableName());
}
private:
z3::context context;
std::stack<z3::expr> stack;
std::map<std::string, z3::expr> variableToExpressionMap
}
} // namespace adapters
} // namespace storm
#endif /* STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_ */

1
src/counterexamples/MILPMinimalLabelSetGenerator.h

@ -926,7 +926,6 @@ namespace storm {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) {
if (state != *predecessorIt) {
predecessors.insert(*predecessorIt);
}
}

129
src/counterexamples/SMTMinimalCommandSetGenerator.h

@ -31,6 +31,12 @@ namespace storm {
class SMTMinimalCommandSetGenerator {
#ifdef STORM_HAVE_Z3
private:
struct RelevancyInformation {
storm::storage::BitVector relevantStates;
std::set<uint_fast64_t> relevantLabels;
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
};
struct VariableInformation {
std::vector<z3::expr> labelVariables;
std::vector<z3::expr> auxiliaryVariables;
@ -44,17 +50,19 @@ namespace storm {
* @param labeledMdp The MDP to search for relevant labels.
* @param phiStates A bit vector representing all states that satisfy phi.
* @param psiStates A bit vector representing all states that satisfy psi.
* @return A set of relevant labels, where relevant is defined as above.
* @return A structure containing the relevant labels as well as states.
*/
static std::set<uint_fast64_t> getRelevantLabels(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) {
static RelevancyInformation determineRelevantStatesAndLabels(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) {
// Create result.
std::set<uint_fast64_t> relevantLabels;
RelevancyInformation relevancyInformation;
// Compute all relevant states, i.e. states for which there exists a scheduler that has a non-zero
// probabilitiy of satisfying phi until psi.
storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
storm::storage::BitVector relevantStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates);
relevantStates &= ~psiStates;
relevancyInformation.relevantStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates);
relevancyInformation.relevantStates &= ~psiStates;
LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantStates.getNumberOfSetBits() << " relevant states.");
// Retrieve some references for convenient access.
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
@ -64,21 +72,29 @@ namespace storm {
// Now traverse all choices of all relevant states and check whether there is a successor target state.
// If so, the associated labels become relevant. Also, if a choice of relevant state has at least one
// relevant successor, the choice becomes relevant.
for (auto state : relevantStates) {
for (auto state : relevancyInformation.relevantStates) {
relevancyInformation.relevantChoicesForRelevantStates.emplace(state, std::list<uint_fast64_t>());
for (uint_fast64_t row = nondeterministicChoiceIndices[state]; row < nondeterministicChoiceIndices[state + 1]; ++row) {
bool currentChoiceRelevant = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) {
// If there is a relevant successor, we need to add the labels of the current choice.
if (relevantStates.get(*successorIt) || psiStates.get(*successorIt)) {
if (relevancyInformation.relevantStates.get(*successorIt) || psiStates.get(*successorIt)) {
for (auto const& label : choiceLabeling[row]) {
relevantLabels.insert(label);
relevancyInformation.relevantLabels.insert(label);
}
if (!currentChoiceRelevant) {
currentChoiceRelevant = true;
relevancyInformation.relevantChoicesForRelevantStates[state].push_back(row);
}
}
}
}
}
LOG4CPLUS_DEBUG(logger, "Found " << relevantLabels.size() << " relevant labels.");
return relevantLabels;
LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantLabels.size() << " relevant labels.");
return relevancyInformation;
}
/*!
@ -119,11 +135,12 @@ namespace storm {
* Asserts the constraints that are initially known.
*
* @param program The program for which to build the constraints.
* @param labeledMdp The MDP that results from the given program.
* @param context The Z3 context in which to build the expressions.
* @param solver The solver in which to assert the constraints.
* @param variableInformation A structure with information about the variables for the labels.
*/
static void assertInitialConstraints(storm::ir::Program const& program, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation) {
static void assertInitialConstraints(storm::ir::Program const& program, storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& psiStates, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation) {
// Assert that at least one of the labels must be taken.
z3::expr formula = variableInformation.labelVariables.at(0);
for (uint_fast64_t index = 1; index < variableInformation.labelVariables.size(); ++index) {
@ -134,8 +151,84 @@ namespace storm {
for (uint_fast64_t index = 0; index < variableInformation.labelVariables.size(); ++index) {
solver.add(!variableInformation.labelVariables[index] || variableInformation.auxiliaryVariables[index]);
}
std::vector<std::set<uint_fast64_t>> const& choiceLabeling = labeledMdp.getChoiceLabeling();
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
// Assert that at least one of the labels of one of the relevant initial states is taken.
std::vector<z3::expr> expressionVector;
bool firstAssignment = true;
for (auto state : labeledMdp.getInitialStates()) {
if (relevancyInformation.relevantStates.get(state)) {
for (auto const& choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
for (auto const& label : choiceLabeling[choice]) {
z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
if (firstAssignment) {
expressionVector.push_back(labelExpression);
firstAssignment = false;
} else {
expressionVector.back() = expressionVector.back() && labelExpression;
}
}
}
}
}
assertDisjunction(context, solver, expressionVector);
// Assert that at least one of the labels that are selected can reach a target state in one step.
storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
// Compute the set of predecessors of target states.
std::unordered_set<uint_fast64_t> predecessors;
for (auto state : psiStates) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) {
if (state != *predecessorIt) {
predecessors.insert(*predecessorIt);
}
}
}
expressionVector.clear();
firstAssignment = true;
for (auto state : predecessors) {
for (auto choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(choice); successorIt != transitionMatrix.constColumnIteratorEnd(choice); ++successorIt) {
if (psiStates.get(*successorIt)) {
for (auto const& label : choiceLabeling[choice]) {
z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
if (firstAssignment) {
expressionVector.push_back(labelExpression);
firstAssignment = false;
} else {
expressionVector.back() = expressionVector.back() && labelExpression;
}
}
}
}
}
}
assertDisjunction(context, solver, expressionVector);
}
/*!
* Asserts that the disjunction of the given formulae holds.
*
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
* @param formulaVector A vector of expressions that shall form the disjunction.
*/
static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) {
z3::expr disjunction(context);
for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) {
if (i == 0) {
disjunction = formulaVector[i];
} else {
disjunction = disjunction || formulaVector[i];
}
}
solver.add(disjunction);
}
/*!
* Asserts that at most one of the blocking variables may be true at any time.
*
@ -288,20 +381,20 @@ namespace storm {
// (1) FIXME: check whether its possible to exceed the threshold if checkThresholdFeasible is set.
// (2) Identify all commands that are relevant, because only these need to be considered later.
std::set<uint_fast64_t> relevantCommands = getRelevantLabels(labeledMdp, phiStates, psiStates);
// (2) Identify all states and commands that are relevant, because only these need to be considered later.
RelevancyInformation relevancyInformation = determineRelevantStatesAndLabels(labeledMdp, phiStates, psiStates);
// (3) Create context for solver.
z3::context context;
// (4) Create the variables for the relevant commands.
VariableInformation variableInformation = createExpressionsForRelevantLabels(context, relevantCommands);
VariableInformation variableInformation = createExpressionsForRelevantLabels(context, relevancyInformation.relevantLabels);
// (5) After all variables have been created, create a solver for that context.
z3::solver solver(context);
// (5) Build the initial constraint system.
assertInitialConstraints(program, context, solver, variableInformation);
assertInitialConstraints(program, labeledMdp, psiStates, context, solver, variableInformation, relevancyInformation);
// (6) Find the smallest set of commands that satisfies all constraints. If the probability of
// satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned.
@ -323,13 +416,16 @@ namespace storm {
std::set<uint_fast64_t> commandSet;
double maximalReachabilityProbability = 0;
bool done = false;
uint_fast64_t iterations = 0;
do {
commandSet = findSmallestCommandSet(context, solver, variableInformation, softConstraints, nextFreeVariableIndex);
// Restrict the given MDP to the current set of labels and compute the reachability probability.
storm::models::Mdp<T> subMdp = labeledMdp.restrictChoiceLabels(commandSet);
storm::modelchecker::prctl::SparseMdpPrctlModelChecker<T> modelchecker(subMdp, new storm::solver::GmmxxNondeterministicLinearEquationSolver<T>());
LOG4CPLUS_DEBUG(logger, "Invoking model checker.");
std::vector<T> result = modelchecker.checkUntil(false, phiStates, psiStates, false, nullptr);
LOG4CPLUS_DEBUG(logger, "Computed model checking results.");
// Now determine the maximal reachability probability by checking all initial states.
for (auto state : labeledMdp.getInitialStates()) {
@ -342,7 +438,8 @@ namespace storm {
} else {
done = true;
}
std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl;
std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands in iteration " << iterations << "." << std::endl;
++iterations;
} while (!done);
std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl;

14
src/storm.cpp

@ -338,13 +338,13 @@ int main(const int argc, const char* argv[]) {
model->printModelInformationToStream(std::cout);
// Enable the following lines to test the MinimalLabelSetGenerator.
if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
}
// if (model->getType() == storm::models::MDP) {
// std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
// storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
// }
// Enable the following lines to test the SMTMinimalCommandSetGenerator.
if (model->getType() == storm::models::MDP) {

Loading…
Cancel
Save