diff --git a/src/storm/adapters/Z3ExpressionAdapter.cpp b/src/storm/adapters/Z3ExpressionAdapter.cpp index e55f5fe3c..c7e27756d 100644 --- a/src/storm/adapters/Z3ExpressionAdapter.cpp +++ b/src/storm/adapters/Z3ExpressionAdapter.cpp @@ -18,8 +18,10 @@ namespace storm { z3::expr Z3ExpressionAdapter::translateExpression(storm::expressions::Expression const& expression) { STORM_LOG_ASSERT(expression.getManager() == this->manager, "Invalid expression for solver."); + z3::expr result = boost::any_cast(expression.getBaseExpression().accept(*this, boost::none)); - + expressionCache.clear(); + for (z3::expr const& assertion : additionalAssertions) { result = result && assertion; } @@ -160,124 +162,220 @@ namespace storm { } boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) { + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + z3::expr leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); z3::expr rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + z3::expr result(context); switch(expression.getOperatorType()) { case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: - return leftResult && rightResult; + result = leftResult && rightResult; + break; case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: - return leftResult || rightResult; + result = leftResult || rightResult; + break; case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor: - return z3::expr(context, Z3_mk_xor(context, leftResult, rightResult)); + result = z3::expr(context, Z3_mk_xor(context, leftResult, rightResult)); + break; case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: - return z3::expr(context, Z3_mk_implies(context, leftResult, rightResult)); + result = z3::expr(context, Z3_mk_implies(context, leftResult, rightResult)); + break; case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: - return z3::expr(context, Z3_mk_iff(context, leftResult, rightResult)); + result = z3::expr(context, Z3_mk_iff(context, leftResult, rightResult)); + break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) { + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + z3::expr leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); z3::expr rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + z3::expr result(context); switch(expression.getOperatorType()) { case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus: - return leftResult + rightResult; + result = leftResult + rightResult; + break; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus: - return leftResult - rightResult; + result = leftResult - rightResult; + break; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times: - return leftResult * rightResult; + result = leftResult * rightResult; + break; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide: - return leftResult / rightResult; + result = leftResult / rightResult; + break; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min: - return ite(leftResult <= rightResult, leftResult, rightResult); + result = ite(leftResult <= rightResult, leftResult, rightResult); + break; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: - return ite(leftResult >= rightResult, leftResult, rightResult); + result = ite(leftResult >= rightResult, leftResult, rightResult); + break; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Power: - return pw(leftResult,rightResult); + result = pw(leftResult,rightResult); + break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) { + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + z3::expr leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this, data)); z3::expr rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + z3::expr result(context); switch(expression.getRelationType()) { case storm::expressions::BinaryRelationExpression::RelationType::Equal: - return leftResult == rightResult; + result = leftResult == rightResult; + break; case storm::expressions::BinaryRelationExpression::RelationType::NotEqual: - return leftResult != rightResult; + result = leftResult != rightResult; + break; case storm::expressions::BinaryRelationExpression::RelationType::Less: - return leftResult < rightResult; + result = leftResult < rightResult; + break; case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual: - return leftResult <= rightResult; + result = leftResult <= rightResult; + break; case storm::expressions::BinaryRelationExpression::RelationType::Greater: - return leftResult > rightResult; + result = leftResult > rightResult; + break; case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: - return leftResult >= rightResult; + result = leftResult >= rightResult; + break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getRelationType()) << "' in expression " << expression << "."); } + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) { - return context.bool_val(expression.getValue()); + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + + z3::expr result = context.bool_val(expression.getValue()); + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) { + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + std::stringstream fractionStream; fractionStream << expression.getValue(); - return context.real_val(fractionStream.str().c_str()); + z3::expr result = context.real_val(fractionStream.str().c_str()); + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) { - return context.int_val(static_cast(expression.getValue())); + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + + z3::expr result = context.int_val(static_cast(expression.getValue())); + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) { - z3::expr childResult = boost::any_cast(expression.getOperand()->accept(*this, data)); + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + + z3::expr result = boost::any_cast(expression.getOperand()->accept(*this, data)); switch (expression.getOperatorType()) { case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: - return !childResult; + result = !result; + break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); - } + } + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) { - z3::expr childResult = boost::any_cast(expression.getOperand()->accept(*this, data)); + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + + z3::expr result = boost::any_cast(expression.getOperand()->accept(*this, data)); switch(expression.getOperatorType()) { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: - return 0 - childResult; + result = 0 - result; + break; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: { storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true); z3::expr floorVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); - additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); - return floorVariable; + additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= result && result < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); + result = floorVariable; + break; } case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{ storm::expressions::Variable freshAuxiliaryVariable = manager.declareFreshVariable(manager.getIntegerType(), true); z3::expr ceilVariable = context.int_const(freshAuxiliaryVariable.getName().c_str()); - additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); - return ceilVariable; + additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= result && result < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); + result = ceilVariable; + break; } default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast(expression.getOperatorType()) << "'."); } + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) { + auto cacheIt = expressionCache.find(&expression); + if (cacheIt != expressionCache.end()) { + return cacheIt->second; + } + z3::expr conditionResult = boost::any_cast(expression.getCondition()->accept(*this, data)); z3::expr thenResult = boost::any_cast(expression.getThenExpression()->accept(*this, data)); z3::expr elseResult = boost::any_cast(expression.getElseExpression()->accept(*this, data)); - return z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult)); + z3::expr result = z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult)); + + expressionCache.emplace(&expression, result); + return result; } boost::any Z3ExpressionAdapter::visit(storm::expressions::VariableExpression const& expression, boost::any const&) { diff --git a/src/storm/adapters/Z3ExpressionAdapter.h b/src/storm/adapters/Z3ExpressionAdapter.h index ddc6e185c..11ed5fe8a 100644 --- a/src/storm/adapters/Z3ExpressionAdapter.h +++ b/src/storm/adapters/Z3ExpressionAdapter.h @@ -16,6 +16,10 @@ #include "storm/storage/expressions/ExpressionVisitor.h" namespace storm { + namespace expressions { + class BaseExpression; + } + namespace adapters { #ifdef STORM_HAVE_Z3 @@ -99,6 +103,9 @@ namespace storm { // A mapping from z3 declarations to the corresponding variables. std::unordered_map declarationToVariableMapping; + + // A cache of already translated constraints. Only valid during the translation of one expression. + std::unordered_map expressionCache; }; #endif } // namespace adapters diff --git a/src/storm/counterexamples/SMTMinimalLabelSetGenerator.h b/src/storm/counterexamples/SMTMinimalLabelSetGenerator.h index ddfcdaf6b..79bb2591f 100644 --- a/src/storm/counterexamples/SMTMinimalLabelSetGenerator.h +++ b/src/storm/counterexamples/SMTMinimalLabelSetGenerator.h @@ -1017,8 +1017,16 @@ namespace storm { * result bit. */ static std::pair createFullAdder(storm::expressions::Expression in1, storm::expressions::Expression in2, storm::expressions::Expression carryIn) { - storm::expressions::Expression resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn) || (in1 && in2 && carryIn); - storm::expressions::Expression carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn; + storm::expressions::Expression resultBit; + storm::expressions::Expression carryBit; + + if (carryIn.isFalse()) { + resultBit = (in1 && !in2) || (!in1 && in2); + carryBit = in1 && in2; + } else { + resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn) || (in1 && in2 && carryIn); + carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn; + } return std::make_pair(carryBit, resultBit); } @@ -1090,8 +1098,6 @@ namespace storm { * @return A bit vector representing the number of literals that are set to true. */ static std::vector createCounterCircuit(VariableInformation const& variableInformation, std::vector const& literals) { - STORM_LOG_DEBUG("Creating counter circuit for " << literals.size() << " literals."); - if (literals.empty()) { return std::vector(); } @@ -1106,7 +1112,9 @@ namespace storm { while (aux.size() > 1) { aux = createAdderPairs(variableInformation, aux); } - + + STORM_LOG_DEBUG("Created counter circuit for " << literals.size() << " literals."); + return aux.front(); } @@ -1307,6 +1315,7 @@ namespace storm { * @param variableInformation A structure with information about the variables of the solver. */ static std::vector assertAdder(storm::solver::SmtSolver& solver, VariableInformation const& variableInformation) { + auto start = std::chrono::high_resolution_clock::now(); std::stringstream variableName; std::vector result; @@ -1317,8 +1326,13 @@ namespace storm { variableName << "adder" << i; result.push_back(variableInformation.manager->declareBooleanVariable(variableName.str())); solver.add(storm::expressions::implies(adderVariables[i], result.back())); + STORM_LOG_TRACE("Added bit " << i << " of adder."); } - + + auto end = std::chrono::high_resolution_clock::now(); + uint64_t duration = std::chrono::duration_cast(end - start).count(); + STORM_LOG_DEBUG("Asserted adder in " << duration << "ms."); + return result; }