Browse Source

added cache to Z3 expression translator to speed up the translation of large constraints

main
dehnert 7 years ago
parent
commit
9dea83055b
  1. 158
      src/storm/adapters/Z3ExpressionAdapter.cpp
  2. 7
      src/storm/adapters/Z3ExpressionAdapter.h
  3. 22
      src/storm/counterexamples/SMTMinimalLabelSetGenerator.h

158
src/storm/adapters/Z3ExpressionAdapter.cpp

@ -18,7 +18,9 @@ namespace storm {
z3::expr Z3ExpressionAdapter::translateExpression(storm::expressions::Expression const& expression) { z3::expr Z3ExpressionAdapter::translateExpression(storm::expressions::Expression const& expression) {
STORM_LOG_ASSERT(expression.getManager() == this->manager, "Invalid expression for solver."); STORM_LOG_ASSERT(expression.getManager() == this->manager, "Invalid expression for solver.");
z3::expr result = boost::any_cast<z3::expr>(expression.getBaseExpression().accept(*this, boost::none)); z3::expr result = boost::any_cast<z3::expr>(expression.getBaseExpression().accept(*this, boost::none));
expressionCache.clear();
for (z3::expr const& assertion : additionalAssertions) { for (z3::expr const& assertion : additionalAssertions) {
result = result && assertion; result = result && assertion;
@ -160,124 +162,220 @@ namespace storm {
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) { boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) {
auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this, data)); z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this, data));
z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this, data)); z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this, data));
z3::expr result(context);
switch(expression.getOperatorType()) { switch(expression.getOperatorType()) {
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And:
return leftResult && rightResult; result = leftResult && rightResult;
break;
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or:
return leftResult || rightResult; result = leftResult || rightResult;
break;
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor: case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor:
return z3::expr(context, Z3_mk_xor(context, leftResult, rightResult)); result = z3::expr(context, Z3_mk_xor(context, leftResult, rightResult));
break;
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies:
return z3::expr(context, Z3_mk_implies(context, leftResult, rightResult)); result = z3::expr(context, Z3_mk_implies(context, leftResult, rightResult));
break;
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff:
return z3::expr(context, Z3_mk_iff(context, leftResult, rightResult)); result = z3::expr(context, Z3_mk_iff(context, leftResult, rightResult));
break;
default: default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << "."); STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << ".");
} }
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) { boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) {
auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this, data)); z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this, data));
z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this, data)); z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this, data));
z3::expr result(context);
switch(expression.getOperatorType()) { switch(expression.getOperatorType()) {
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus:
return leftResult + rightResult; result = leftResult + rightResult;
break;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus:
return leftResult - rightResult; result = leftResult - rightResult;
break;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times:
return leftResult * rightResult; result = leftResult * rightResult;
break;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide:
return leftResult / rightResult; result = leftResult / rightResult;
break;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min:
return ite(leftResult <= rightResult, leftResult, rightResult); result = ite(leftResult <= rightResult, leftResult, rightResult);
break;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max:
return ite(leftResult >= rightResult, leftResult, rightResult); result = ite(leftResult >= rightResult, leftResult, rightResult);
break;
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Power: case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Power:
return pw(leftResult,rightResult); result = pw(leftResult,rightResult);
break;
default: default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << "."); STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << ".");
} }
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) { boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) {
auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this, data)); z3::expr leftResult = boost::any_cast<z3::expr>(expression.getFirstOperand()->accept(*this, data));
z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this, data)); z3::expr rightResult = boost::any_cast<z3::expr>(expression.getSecondOperand()->accept(*this, data));
z3::expr result(context);
switch(expression.getRelationType()) { switch(expression.getRelationType()) {
case storm::expressions::BinaryRelationExpression::RelationType::Equal: case storm::expressions::BinaryRelationExpression::RelationType::Equal:
return leftResult == rightResult; result = leftResult == rightResult;
break;
case storm::expressions::BinaryRelationExpression::RelationType::NotEqual: case storm::expressions::BinaryRelationExpression::RelationType::NotEqual:
return leftResult != rightResult; result = leftResult != rightResult;
break;
case storm::expressions::BinaryRelationExpression::RelationType::Less: case storm::expressions::BinaryRelationExpression::RelationType::Less:
return leftResult < rightResult; result = leftResult < rightResult;
break;
case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual: case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual:
return leftResult <= rightResult; result = leftResult <= rightResult;
break;
case storm::expressions::BinaryRelationExpression::RelationType::Greater: case storm::expressions::BinaryRelationExpression::RelationType::Greater:
return leftResult > rightResult; result = leftResult > rightResult;
break;
case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual:
return leftResult >= rightResult; result = leftResult >= rightResult;
break;
default: default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getRelationType()) << "' in expression " << expression << "."); STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getRelationType()) << "' in expression " << expression << ".");
} }
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) { boost::any Z3ExpressionAdapter::visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) {
return context.bool_val(expression.getValue()); auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr result = context.bool_val(expression.getValue());
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) { boost::any Z3ExpressionAdapter::visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) {
auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
std::stringstream fractionStream; std::stringstream fractionStream;
fractionStream << expression.getValue(); fractionStream << expression.getValue();
return context.real_val(fractionStream.str().c_str()); z3::expr result = context.real_val(fractionStream.str().c_str());
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) { boost::any Z3ExpressionAdapter::visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) {
return context.int_val(static_cast<long long>(expression.getValue())); auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr result = context.int_val(static_cast<long long>(expression.getValue()));
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) { boost::any Z3ExpressionAdapter::visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) {
z3::expr childResult = boost::any_cast<z3::expr>(expression.getOperand()->accept(*this, data)); auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr result = boost::any_cast<z3::expr>(expression.getOperand()->accept(*this, data));
switch (expression.getOperatorType()) { switch (expression.getOperatorType()) {
case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not:
return !childResult; result = !result;
break;
default: default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << "."); STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<int>(expression.getOperatorType()) << "' in expression " << expression << ".");
} }
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) { boost::any Z3ExpressionAdapter::visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) {
z3::expr childResult = boost::any_cast<z3::expr>(expression.getOperand()->accept(*this, data)); auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr result = boost::any_cast<z3::expr>(expression.getOperand()->accept(*this, data));
switch(expression.getOperatorType()) { switch(expression.getOperatorType()) {
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus:
return 0 - childResult; result = 0 - result;
break;
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: {
storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true); storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true);
z3::expr floorVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); z3::expr floorVariable = context.int_const(freshAuxiliaryVariable.getName().c_str());
additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= result && result < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1));
return floorVariable; result = floorVariable;
break;
} }
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{ case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{
storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true); storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true);
z3::expr ceilVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); z3::expr ceilVariable = context.int_const(freshAuxiliaryVariable.getName().c_str());
additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= result && result < z3::expr(context, Z3_mk_int2real(context, ceilVariable)));
return ceilVariable; result = ceilVariable;
break;
} }
default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast<int>(expression.getOperatorType()) << "'."); default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast<int>(expression.getOperatorType()) << "'.");
} }
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) { boost::any Z3ExpressionAdapter::visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) {
auto cacheIt = expressionCache.find(&expression);
if (cacheIt != expressionCache.end()) {
return cacheIt->second;
}
z3::expr conditionResult = boost::any_cast<z3::expr>(expression.getCondition()->accept(*this, data)); z3::expr conditionResult = boost::any_cast<z3::expr>(expression.getCondition()->accept(*this, data));
z3::expr thenResult = boost::any_cast<z3::expr>(expression.getThenExpression()->accept(*this, data)); z3::expr thenResult = boost::any_cast<z3::expr>(expression.getThenExpression()->accept(*this, data));
z3::expr elseResult = boost::any_cast<z3::expr>(expression.getElseExpression()->accept(*this, data)); z3::expr elseResult = boost::any_cast<z3::expr>(expression.getElseExpression()->accept(*this, data));
return z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult)); z3::expr result = z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult));
expressionCache.emplace(&expression, result);
return result;
} }
boost::any Z3ExpressionAdapter::visit(storm::expressions::VariableExpression const& expression, boost::any const&) { boost::any Z3ExpressionAdapter::visit(storm::expressions::VariableExpression const& expression, boost::any const&) {

7
src/storm/adapters/Z3ExpressionAdapter.h

@ -16,6 +16,10 @@
#include "storm/storage/expressions/ExpressionVisitor.h" #include "storm/storage/expressions/ExpressionVisitor.h"
namespace storm { namespace storm {
namespace expressions {
class BaseExpression;
}
namespace adapters { namespace adapters {
#ifdef STORM_HAVE_Z3 #ifdef STORM_HAVE_Z3
@ -99,6 +103,9 @@ namespace storm {
// A mapping from z3 declarations to the corresponding variables. // A mapping from z3 declarations to the corresponding variables.
std::unordered_map<Z3_func_decl, storm::expressions::Variable> declarationToVariableMapping; std::unordered_map<Z3_func_decl, storm::expressions::Variable> declarationToVariableMapping;
// A cache of already translated constraints. Only valid during the translation of one expression.
std::unordered_map<storm::expressions::BaseExpression const*, z3::expr> expressionCache;
}; };
#endif #endif
} // namespace adapters } // namespace adapters

22
src/storm/counterexamples/SMTMinimalLabelSetGenerator.h

@ -1017,8 +1017,16 @@ namespace storm {
* result bit. * result bit.
*/ */
static std::pair<storm::expressions::Expression, storm::expressions::Expression> createFullAdder(storm::expressions::Expression in1, storm::expressions::Expression in2, storm::expressions::Expression carryIn) { static std::pair<storm::expressions::Expression, storm::expressions::Expression> createFullAdder(storm::expressions::Expression in1, storm::expressions::Expression in2, storm::expressions::Expression carryIn) {
storm::expressions::Expression resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn) || (in1 && in2 && carryIn); storm::expressions::Expression resultBit;
storm::expressions::Expression carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn; storm::expressions::Expression carryBit;
if (carryIn.isFalse()) {
resultBit = (in1 && !in2) || (!in1 && in2);
carryBit = in1 && in2;
} else {
resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn) || (in1 && in2 && carryIn);
carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn;
}
return std::make_pair(carryBit, resultBit); return std::make_pair(carryBit, resultBit);
} }
@ -1090,8 +1098,6 @@ namespace storm {
* @return A bit vector representing the number of literals that are set to true. * @return A bit vector representing the number of literals that are set to true.
*/ */
static std::vector<storm::expressions::Expression> createCounterCircuit(VariableInformation const& variableInformation, std::vector<storm::expressions::Variable> const& literals) { static std::vector<storm::expressions::Expression> createCounterCircuit(VariableInformation const& variableInformation, std::vector<storm::expressions::Variable> const& literals) {
STORM_LOG_DEBUG("Creating counter circuit for " << literals.size() << " literals.");
if (literals.empty()) { if (literals.empty()) {
return std::vector<storm::expressions::Expression>(); return std::vector<storm::expressions::Expression>();
} }
@ -1107,6 +1113,8 @@ namespace storm {
aux = createAdderPairs(variableInformation, aux); aux = createAdderPairs(variableInformation, aux);
} }
STORM_LOG_DEBUG("Created counter circuit for " << literals.size() << " literals.");
return aux.front(); return aux.front();
} }
@ -1307,6 +1315,7 @@ namespace storm {
* @param variableInformation A structure with information about the variables of the solver. * @param variableInformation A structure with information about the variables of the solver.
*/ */
static std::vector<storm::expressions::Variable> assertAdder(storm::solver::SmtSolver& solver, VariableInformation const& variableInformation) { static std::vector<storm::expressions::Variable> assertAdder(storm::solver::SmtSolver& solver, VariableInformation const& variableInformation) {
auto start = std::chrono::high_resolution_clock::now();
std::stringstream variableName; std::stringstream variableName;
std::vector<storm::expressions::Variable> result; std::vector<storm::expressions::Variable> result;
@ -1317,8 +1326,13 @@ namespace storm {
variableName << "adder" << i; variableName << "adder" << i;
result.push_back(variableInformation.manager->declareBooleanVariable(variableName.str())); result.push_back(variableInformation.manager->declareBooleanVariable(variableName.str()));
solver.add(storm::expressions::implies(adderVariables[i], result.back())); solver.add(storm::expressions::implies(adderVariables[i], result.back()));
STORM_LOG_TRACE("Added bit " << i << " of adder.");
} }
auto end = std::chrono::high_resolution_clock::now();
uint64_t duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
STORM_LOG_DEBUG("Asserted adder in " << duration << "ms.");
return result; return result;
} }

|||||||
100:0
Loading…
Cancel
Save