diff --git a/src/adapters/ExplicitModelAdapter.h b/src/adapters/ExplicitModelAdapter.h index e120f4a59..25804b58a 100644 --- a/src/adapters/ExplicitModelAdapter.h +++ b/src/adapters/ExplicitModelAdapter.h @@ -88,62 +88,6 @@ namespace storm { std::vector> choiceLabeling; }; - static std::map parseConstantDefinitionString(storm::prism::Program const& program, std::string const& constantDefinitionString) { - std::map constantDefinitions; - std::set definedConstants; - - if (!constantDefinitionString.empty()) { - // Parse the string that defines the undefined constants of the model and make sure that it contains exactly - // one value for each undefined constant of the model. - std::vector definitions; - boost::split(definitions, constantDefinitionString, boost::is_any_of(",")); - for (auto& definition : definitions) { - boost::trim(definition); - - // Check whether the token could be a legal constant definition. - uint_fast64_t positionOfAssignmentOperator = definition.find('='); - if (positionOfAssignmentOperator == std::string::npos) { - throw storm::exceptions::InvalidArgumentException() << "Illegal constant definition string: syntax error."; - } - - // Now extract the variable name and the value from the string. - std::string constantName = definition.substr(0, positionOfAssignmentOperator); - boost::trim(constantName); - std::string value = definition.substr(positionOfAssignmentOperator + 1); - boost::trim(value); - - // Check whether the constant is a legal undefined constant of the program and if so, of what type it is. - if (program.hasConstant(constantName)) { - // Get the actual constant and check whether it's in fact undefined. - auto const& constant = program.getConstant(constantName); - LOG_THROW(!constant.isDefined(), storm::exceptions::InvalidArgumentException, "Illegally trying to define already defined constant '" << constantName <<"'."); - LOG_THROW(definedConstants.find(constantName) == definedConstants.end(), storm::exceptions::InvalidArgumentException, "Illegally trying to define constant '" << constantName <<"' twice."); - definedConstants.insert(constantName); - - if (constant.getType() == storm::expressions::ExpressionReturnType::Bool) { - if (value == "true") { - constantDefinitions[constantName] = storm::expressions::Expression::createTrue(); - } else if (value == "false") { - constantDefinitions[constantName] = storm::expressions::Expression::createFalse(); - } else { - throw storm::exceptions::InvalidArgumentException() << "Illegal value for boolean constant: " << value << "."; - } - } else if (constant.getType() == storm::expressions::ExpressionReturnType::Int) { - int_fast64_t integerValue = std::stoi(value); - constantDefinitions[constantName] = storm::expressions::Expression::createIntegerLiteral(integerValue); - } else if (constant.getType() == storm::expressions::ExpressionReturnType::Double) { - double doubleValue = std::stod(value); - constantDefinitions[constantName] = storm::expressions::Expression::createDoubleLiteral(doubleValue); - } - } else { - throw storm::exceptions::InvalidArgumentException() << "Illegal constant definition string: unknown undefined constant " << constantName << "."; - } - } - } - - return constantDefinitions; - } - /*! * Convert the program given at construction time to an abstract model. The type of the model is the one * specified in the program. The given reward model name selects the rewards that the model will contain. @@ -159,7 +103,7 @@ namespace storm { static std::unique_ptr> translateProgram(storm::prism::Program program, std::string const& constantDefinitionString = "", std::string const& rewardModelName = "") { // Start by defining the undefined constants in the model. // First, we need to parse the constant definition string. - std::map constantDefinitions = parseConstantDefinitionString(program, constantDefinitionString); + std::map constantDefinitions = storm::utility::prism::parseConstantDefinitionString(program, constantDefinitionString); storm::prism::Program preparedProgram = program.defineUndefinedConstants(constantDefinitions); LOG_THROW(!preparedProgram.hasUndefinedConstants(), storm::exceptions::InvalidArgumentException, "Program still contains undefined constants."); diff --git a/src/adapters/Z3ExpressionAdapter.h b/src/adapters/Z3ExpressionAdapter.h index 09bc1dac4..7f850f9d3 100644 --- a/src/adapters/Z3ExpressionAdapter.h +++ b/src/adapters/Z3ExpressionAdapter.h @@ -10,96 +10,304 @@ #include -#include "src/ir/expressions/ExpressionVisitor.h" -#include "src/ir/expressions/Expressions.h" +// Include the headers of Z3 only if it is available. +#ifdef STORM_HAVE_Z3 +#include "z3++.h" +#include "z3.h" +#endif + +#include "storm-config.h" +#include "src/storage/expressions/Expressions.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/ExpressionEvaluationException.h" +#include "src/exceptions/InvalidTypeException.h" +#include "src/exceptions/NotImplementedException.h" namespace storm { namespace adapters { - - class Z3ExpressionAdapter : public storm::ir::expressions::ExpressionVisitor { + +#ifdef STORM_HAVE_Z3 + class Z3ExpressionAdapter : public storm::expressions::ExpressionVisitor { public: /*! * Creates a Z3ExpressionAdapter over the given Z3 context. + * + * @remark The adapter internally creates helper variables prefixed with `__z3adapter_`. Avoid having variables with + * this prefix in the variableToExpressionMap, as this might lead to unexpected results. * * @param context A reference to the Z3 context over which to build the expressions. Be careful to guarantee * the lifetime of the context as long as the instance of this adapter is used. * @param variableToExpressionMap A mapping from variable names to their corresponding Z3 expressions. */ - Z3ExpressionAdapter(z3::context& context, std::map const& variableToExpressionMap) : context(context), stack(), variableToExpressionMap(variableToExpressionMap) { + Z3ExpressionAdapter(z3::context& context, std::map const& variableToExpressionMap) + : context(context) + , stack() + , additionalAssertions() + , additionalVariableCounter(0) + , variableToExpressionMap(variableToExpressionMap) { // Intentionally left empty. } /*! * Translates the given expression to an equivalent expression for Z3. - * + * + * @remark The adapter internally creates helper variables prefixed with `__z3adapter_`. Avoid having variables with + * this prefix in the expression, as this might lead to unexpected results. + * * @param expression The expression to translate. + * @param createZ3Variables If set to true a solver variable is created for each variable in expression that is not + * yet known to the adapter. (i.e. values from the variableToExpressionMap passed to the constructor + * are not overwritten) * @return An equivalent expression for Z3. */ - z3::expr translateExpression(std::unique_ptr const& expression) { - expression->accept(this); + z3::expr translateExpression(storm::expressions::Expression const& expression, bool createZ3Variables = false) { + if (createZ3Variables) { + std::map variables; + + try { + variables = expression.getVariablesAndTypes(); + } + catch (storm::exceptions::InvalidTypeException* e) { + LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with ambigious type while trying to autocreate solver variables: " << e); + } + + for (auto variableAndType : variables) { + if (this->variableToExpressionMap.find(variableAndType.first) == this->variableToExpressionMap.end()) { + switch (variableAndType.second) + { + case storm::expressions::ExpressionReturnType::Bool: + this->variableToExpressionMap.insert(std::make_pair(variableAndType.first, context.bool_const(variableAndType.first.c_str()))); + break; + case storm::expressions::ExpressionReturnType::Int: + this->variableToExpressionMap.insert(std::make_pair(variableAndType.first, context.int_const(variableAndType.first.c_str()))); + break; + case storm::expressions::ExpressionReturnType::Double: + this->variableToExpressionMap.insert(std::make_pair(variableAndType.first, context.real_const(variableAndType.first.c_str()))); + break; + default: + LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with unknown type while trying to autocreate solver variables: " << variableAndType.first); + break; + } + } + } + } + + expression.getBaseExpression().accept(this); z3::expr result = stack.top(); stack.pop(); + + while (!additionalAssertions.empty()) { + result = result && additionalAssertions.top(); + additionalAssertions.pop(); + } + return result; } - virtual void visit(ir::expressions::BinaryBooleanFunctionExpression* expression) { - expression->getLeft()->accept(this); - expression->getRight()->accept(this); + storm::expressions::Expression translateExpression(z3::expr const& expr) { + //std::cout << std::boolalpha << expr.is_var() << std::endl; + //std::cout << std::boolalpha << expr.is_app() << std::endl; + //std::cout << expr.decl().decl_kind() << std::endl; + + /* + if (expr.is_bool() && expr.is_const()) { + switch (Z3_get_bool_value(expr.ctx(), expr)) { + case Z3_L_FALSE: + return storm::expressions::Expression::createFalse(); + case Z3_L_TRUE: + return storm::expressions::Expression::createTrue(); + break; + default: + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant boolean, but value is undefined."); + break; + } + } else if (expr.is_int() && expr.is_const()) { + int_fast64_t value; + if (Z3_get_numeral_int64(expr.ctx(), expr, &value)) { + return storm::expressions::Expression::createIntegerLiteral(value); + } else { + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant integer and value does not fit into 64-bit integer."); + } + } else if (expr.is_real() && expr.is_const()) { + int_fast64_t num; + int_fast64_t den; + if (Z3_get_numeral_rational_int64(expr.ctx(), expr, &num, &den)) { + return storm::expressions::Expression::createDoubleLiteral(static_cast(num) / static_cast(den)); + } else { + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant real and value does not fit into a fraction with 64-bit integer numerator and denominator."); + } + } else */ + if (expr.is_app()) { + switch (expr.decl().decl_kind()) { + case Z3_OP_TRUE: + return storm::expressions::Expression::createTrue(); + case Z3_OP_FALSE: + return storm::expressions::Expression::createFalse(); + case Z3_OP_EQ: + return this->translateExpression(expr.arg(0)) == this->translateExpression(expr.arg(1)); + case Z3_OP_ITE: + return this->translateExpression(expr.arg(0)).ite(this->translateExpression(expr.arg(1)), this->translateExpression(expr.arg(2))); + case Z3_OP_AND: { + unsigned args = expr.num_args(); + LOG_THROW(args != 0, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. 0-ary AND is assumed to be an error."); + if (args == 1) { + return this->translateExpression(expr.arg(0)); + } else { + storm::expressions::Expression retVal = this->translateExpression(expr.arg(0)); + for (unsigned i = 1; i < args; i++) { + retVal = retVal && this->translateExpression(expr.arg(i)); + } + return retVal; + } + } + case Z3_OP_OR: { + unsigned args = expr.num_args(); + LOG_THROW(args != 0, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. 0-ary OR is assumed to be an error."); + if (args == 1) { + return this->translateExpression(expr.arg(0)); + } else { + storm::expressions::Expression retVal = this->translateExpression(expr.arg(0)); + for (unsigned i = 1; i < args; i++) { + retVal = retVal || this->translateExpression(expr.arg(i)); + } + return retVal; + } + } + case Z3_OP_IFF: + return this->translateExpression(expr.arg(0)).iff(this->translateExpression(expr.arg(1))); + case Z3_OP_XOR: + return this->translateExpression(expr.arg(0)) ^ this->translateExpression(expr.arg(1)); + case Z3_OP_NOT: + return !this->translateExpression(expr.arg(0)); + case Z3_OP_IMPLIES: + return this->translateExpression(expr.arg(0)).implies(this->translateExpression(expr.arg(1))); + case Z3_OP_LE: + return this->translateExpression(expr.arg(0)) <= this->translateExpression(expr.arg(1)); + case Z3_OP_GE: + return this->translateExpression(expr.arg(0)) >= this->translateExpression(expr.arg(1)); + case Z3_OP_LT: + return this->translateExpression(expr.arg(0)) < this->translateExpression(expr.arg(1)); + case Z3_OP_GT: + return this->translateExpression(expr.arg(0)) > this->translateExpression(expr.arg(1)); + case Z3_OP_ADD: + return this->translateExpression(expr.arg(0)) + this->translateExpression(expr.arg(1)); + case Z3_OP_SUB: + return this->translateExpression(expr.arg(0)) - this->translateExpression(expr.arg(1)); + case Z3_OP_UMINUS: + return -this->translateExpression(expr.arg(0)); + case Z3_OP_MUL: + return this->translateExpression(expr.arg(0)) * this->translateExpression(expr.arg(1)); + case Z3_OP_DIV: + return this->translateExpression(expr.arg(0)) / this->translateExpression(expr.arg(1)); + case Z3_OP_IDIV: + return this->translateExpression(expr.arg(0)) / this->translateExpression(expr.arg(1)); + case Z3_OP_ANUM: + //Arithmetic numeral + if (expr.is_int() && expr.is_const()) { + long long value; + if (Z3_get_numeral_int64(expr.ctx(), expr, &value)) { + return storm::expressions::Expression::createIntegerLiteral(value); + } else { + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant integer and value does not fit into 64-bit integer."); + } + } else if (expr.is_real() && expr.is_const()) { + long long num; + long long den; + if (Z3_get_numeral_rational_int64(expr.ctx(), expr, &num, &den)) { + return storm::expressions::Expression::createDoubleLiteral(static_cast(num) / static_cast(den)); + } else { + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Expression is constant real and value does not fit into a fraction with 64-bit integer numerator and denominator."); + } + } + case Z3_OP_UNINTERPRETED: + //storm only supports uninterpreted constant functions + LOG_THROW(expr.is_const(), storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered non constant uninterpreted function."); + if (expr.is_bool()) { + return storm::expressions::Expression::createBooleanVariable(expr.decl().name().str()); + } else if (expr.is_int()) { + return storm::expressions::Expression::createIntegerVariable(expr.decl().name().str()); + } else if (expr.is_real()) { + return storm::expressions::Expression::createDoubleVariable(expr.decl().name().str()); + } else { + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered constant uninterpreted function of unknown sort."); + } + + default: + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered unhandled Z3_decl_kind " << expr.decl().kind() <<"."); + break; + } + } else { + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Failed to convert Z3 expression. Encountered unknown expression type."); + } + } + + virtual void visit(storm::expressions::BinaryBooleanFunctionExpression const* expression) override { + expression->getFirstOperand()->accept(this); + expression->getSecondOperand()->accept(this); - z3::expr rightResult = stack.top(); + const z3::expr rightResult = stack.top(); stack.pop(); - z3::expr leftResult = stack.top(); + const z3::expr leftResult = stack.top(); stack.pop(); - switch(expression->getFunctionType()) { - case storm::ir::expressions::BinaryBooleanFunctionExpression::AND: + switch(expression->getOperatorType()) { + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: stack.push(leftResult && rightResult); break; - case storm::ir::expressions::BinaryBooleanFunctionExpression::OR: + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: stack.push(leftResult || rightResult); - break; + break; + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor: + stack.push(z3::expr(context, Z3_mk_xor(context, leftResult, rightResult))); + break; + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: + stack.push(z3::expr(context, Z3_mk_implies(context, leftResult, rightResult))); + break; + case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: + stack.push(z3::expr(context, Z3_mk_iff(context, leftResult, rightResult))); + break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << expression->getFunctionType() << "' in expression " << expression->toString() << "."; + << "Unknown boolean binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; } } - virtual void visit(ir::expressions::BinaryNumericalFunctionExpression* expression) { - expression->getLeft()->accept(this); - expression->getRight()->accept(this); + virtual void visit(storm::expressions::BinaryNumericalFunctionExpression const* expression) override { + expression->getFirstOperand()->accept(this); + expression->getSecondOperand()->accept(this); z3::expr rightResult = stack.top(); stack.pop(); z3::expr leftResult = stack.top(); stack.pop(); - switch(expression->getFunctionType()) { - case storm::ir::expressions::BinaryNumericalFunctionExpression::PLUS: + switch(expression->getOperatorType()) { + case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus: stack.push(leftResult + rightResult); break; - case storm::ir::expressions::BinaryNumericalFunctionExpression::MINUS: + case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus: stack.push(leftResult - rightResult); break; - case storm::ir::expressions::BinaryNumericalFunctionExpression::TIMES: + case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times: stack.push(leftResult * rightResult); break; - case storm::ir::expressions::BinaryNumericalFunctionExpression::DIVIDE: + case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide: stack.push(leftResult / rightResult); break; - case storm::ir::expressions::BinaryNumericalFunctionExpression::MIN: + case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min: stack.push(ite(leftResult <= rightResult, leftResult, rightResult)); break; - case storm::ir::expressions::BinaryNumericalFunctionExpression::MAX: + case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: stack.push(ite(leftResult >= rightResult, leftResult, rightResult)); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown numerical binary operator: '" << expression->getFunctionType() << "' in expression " << expression->toString() << "."; + << "Unknown numerical binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; } } - virtual void visit(ir::expressions::BinaryRelationExpression* expression) { - expression->getLeft()->accept(this); - expression->getRight()->accept(this); + virtual void visit(storm::expressions::BinaryRelationExpression const* expression) override { + expression->getFirstOperand()->accept(this); + expression->getSecondOperand()->accept(this); z3::expr rightResult = stack.top(); stack.pop(); @@ -107,112 +315,115 @@ namespace storm { stack.pop(); switch(expression->getRelationType()) { - case storm::ir::expressions::BinaryRelationExpression::EQUAL: + case storm::expressions::BinaryRelationExpression::RelationType::Equal: stack.push(leftResult == rightResult); break; - case storm::ir::expressions::BinaryRelationExpression::NOT_EQUAL: + case storm::expressions::BinaryRelationExpression::RelationType::NotEqual: stack.push(leftResult != rightResult); break; - case storm::ir::expressions::BinaryRelationExpression::LESS: + case storm::expressions::BinaryRelationExpression::RelationType::Less: stack.push(leftResult < rightResult); break; - case storm::ir::expressions::BinaryRelationExpression::LESS_OR_EQUAL: + case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual: stack.push(leftResult <= rightResult); break; - case storm::ir::expressions::BinaryRelationExpression::GREATER: + case storm::expressions::BinaryRelationExpression::RelationType::Greater: stack.push(leftResult > rightResult); break; - case storm::ir::expressions::BinaryRelationExpression::GREATER_OR_EQUAL: + case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: stack.push(leftResult >= rightResult); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << expression->getRelationType() << "' in expression " << expression->toString() << "."; + << "Unknown boolean binary operator: '" << static_cast(expression->getRelationType()) << "' in expression " << expression << "."; } } - virtual void visit(ir::expressions::BooleanConstantExpression* expression) { - if (!expression->isDefined()) { - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << ". Boolean constant '" << expression->getConstantName() << "' is undefined."; - } - - stack.push(context.bool_val(expression->getValue())); - } - - virtual void visit(ir::expressions::BooleanLiteralExpression* expression) { - stack.push(context.bool_val(expression->getValueAsBool(nullptr))); + virtual void visit(storm::expressions::BooleanLiteralExpression const* expression) override { + stack.push(context.bool_val(expression->evaluateAsBool())); } - virtual void visit(ir::expressions::DoubleConstantExpression* expression) { - if (!expression->isDefined()) { - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << ". Double constant '" << expression->getConstantName() << "' is undefined."; - } - + virtual void visit(storm::expressions::DoubleLiteralExpression const* expression) override { std::stringstream fractionStream; - fractionStream << expression->getValue(); + fractionStream << expression->evaluateAsDouble(); stack.push(context.real_val(fractionStream.str().c_str())); } - virtual void visit(ir::expressions::DoubleLiteralExpression* expression) { - std::stringstream fractionStream; - fractionStream << expression->getValueAsDouble(nullptr); - stack.push(context.real_val(fractionStream.str().c_str())); + virtual void visit(storm::expressions::IntegerLiteralExpression const* expression) override { + stack.push(context.int_val(static_cast(expression->evaluateAsInt()))); } - virtual void visit(ir::expressions::IntegerConstantExpression* expression) { - if (!expression->isDefined()) { - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << ". Integer constant '" << expression->getConstantName() << "' is undefined."; - } - - stack.push(context.int_val(static_cast(expression->getValue()))); - } - - virtual void visit(ir::expressions::IntegerLiteralExpression* expression) { - stack.push(context.int_val(static_cast(expression->getValueAsInt(nullptr)))); - } - - virtual void visit(ir::expressions::UnaryBooleanFunctionExpression* expression) { - expression->getChild()->accept(this); + virtual void visit(storm::expressions::UnaryBooleanFunctionExpression const* expression) override { + expression->getOperand()->accept(this); z3::expr childResult = stack.top(); stack.pop(); - switch (expression->getFunctionType()) { - case storm::ir::expressions::UnaryBooleanFunctionExpression::NOT: + switch (expression->getOperatorType()) { + case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: stack.push(!childResult); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << expression->getFunctionType() << "' in expression " << expression->toString() << "."; + << "Unknown boolean binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; } } - virtual void visit(ir::expressions::UnaryNumericalFunctionExpression* expression) { - expression->getChild()->accept(this); + virtual void visit(storm::expressions::UnaryNumericalFunctionExpression const* expression) override { + expression->getOperand()->accept(this); z3::expr childResult = stack.top(); stack.pop(); - switch(expression->getFunctionType()) { - case storm::ir::expressions::UnaryNumericalFunctionExpression::MINUS: + switch(expression->getOperatorType()) { + case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: stack.push(0 - childResult); - break; + break; + case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: { + z3::expr floorVariable = context.int_const(("__z3adapter_floor_" + std::to_string(additionalVariableCounter++)).c_str()); + additionalAssertions.push(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); + stack.push(floorVariable); + //throw storm::exceptions::NotImplementedException() << "Unary numerical function 'floor' is not supported by Z3ExpressionAdapter."; + break; + } + case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{ + z3::expr ceilVariable = context.int_const(("__z3adapter_ceil_" + std::to_string(additionalVariableCounter++)).c_str()); + additionalAssertions.push(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); + stack.push(ceilVariable); + //throw storm::exceptions::NotImplementedException() << "Unary numerical function 'floor' is not supported by Z3ExpressionAdapter."; + break; + } default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown numerical unary operator: '" << expression->getFunctionType() << "'."; + << "Unknown numerical unary operator: '" << static_cast(expression->getOperatorType()) << "'."; } } + + virtual void visit(storm::expressions::IfThenElseExpression const* expression) override { + expression->getCondition()->accept(this); + expression->getThenExpression()->accept(this); + expression->getElseExpression()->accept(this); + + z3::expr conditionResult = stack.top(); + stack.pop(); + z3::expr thenResult = stack.top(); + stack.pop(); + z3::expr elseResult = stack.top(); + stack.pop(); + + stack.push(z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult))); + } - virtual void visit(ir::expressions::VariableExpression* expression) { + virtual void visit(storm::expressions::VariableExpression const* expression) override { stack.push(variableToExpressionMap.at(expression->getVariableName())); } - + private: z3::context& context; std::stack stack; + std::stack additionalAssertions; + uint_fast64_t additionalVariableCounter; + std::map variableToExpressionMap; }; - +#endif } // namespace adapters } // namespace storm diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 2abb27ed2..c942396cb 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -20,6 +20,8 @@ #include "src/adapters/Z3ExpressionAdapter.h" #endif +#include "src/storage/prism/Program.h" +#include "src/storage/expressions/Expression.h" #include "src/adapters/ExplicitModelAdapter.h" #include "src/modelchecker/prctl/SparseMdpPrctlModelChecker.h" #include "src/solver/GmmxxNondeterministicLinearEquationSolver.h" @@ -125,7 +127,7 @@ namespace storm { for (auto const& entry : transitionMatrix.getRow(row)) { // If there is a relevant successor, we need to add the labels of the current choice. - if (relevancyInformation.relevantStates.get(entry.first) || psiStates.get(entry.first)) { + if (relevancyInformation.relevantStates.get(entry.getColumn()) || psiStates.get(entry.getColumn())) { for (auto const& label : choiceLabeling[row]) { relevancyInformation.relevantLabels.insert(label); } @@ -212,21 +214,21 @@ namespace storm { for (auto const& entry : transitionMatrix.getRow(relevantChoice)) { // If the successor state is neither the state itself nor an irrelevant state, we need to add a variable for the transition. - if (state != entry.first && (relevancyInformation.relevantStates.get(entry.first) || psiStates.get(entry.first))) { + if (state != entry.getColumn() && (relevancyInformation.relevantStates.get(entry.getColumn()) || psiStates.get(entry.getColumn()))) { // Make sure that there is not already one variable for the state pair. This may happen because of several nondeterministic choices // targeting the same state. - if (variableInformation.statePairToIndexMap.find(std::make_pair(state, entry.first)) != variableInformation.statePairToIndexMap.end()) { + if (variableInformation.statePairToIndexMap.find(std::make_pair(state, entry.getColumn())) != variableInformation.statePairToIndexMap.end()) { continue; } // At this point we know that the state-pair does not have an associated variable. - variableInformation.statePairToIndexMap[std::make_pair(state, entry.first)] = variableInformation.statePairVariables.size(); + variableInformation.statePairToIndexMap[std::make_pair(state, entry.getColumn())] = variableInformation.statePairVariables.size(); // Clear contents of the stream to construct new expression name. variableName.clear(); variableName.str(""); - variableName << "t" << state << "_" << entry.first; + variableName << "t" << state << "_" << entry.getColumn(); variableInformation.statePairVariables.push_back(context.bool_const(variableName.str().c_str())); } @@ -258,7 +260,7 @@ namespace storm { * @param solver The solver in which to assert the constraints. * @param variableInformation A structure with information about the variables for the labels. */ - static void assertFuMalikInitialConstraints(storm::ir::Program const& program, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation) { + static void assertFuMalikInitialConstraints(storm::prism::Program const& program, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, z3::context& context, z3::solver& solver, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation) { // Assert that at least one of the labels must be taken. z3::expr formula = variableInformation.labelVariables.at(0); for (uint_fast64_t index = 1; index < variableInformation.labelVariables.size(); ++index) { @@ -316,11 +318,11 @@ namespace storm { // Iterate over successors and add relevant choices of relevant successors to the following label set. bool canReachTargetState = false; for (auto const& entry : transitionMatrix.getRow(currentChoice)) { - if (relevancyInformation.relevantStates.get(entry.first)) { - for (auto relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(entry.first)) { + if (relevancyInformation.relevantStates.get(entry.getColumn())) { + for (auto relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(entry.getColumn())) { followingLabels[choiceLabeling[currentChoice]].insert(choiceLabeling[currentChoice]); } - } else if (psiStates.get(entry.first)) { + } else if (psiStates.get(entry.getColumn())) { canReachTargetState = true; } } @@ -335,11 +337,11 @@ namespace storm { // Iterate over predecessors and add all choices that target the current state to the preceding // label set of all labels of all relevant choices of the current state. for (auto const& predecessorEntry : backwardTransitions.getRow(currentState)) { - if (relevancyInformation.relevantStates.get(predecessorEntry.first)) { - for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(predecessorEntry.first)) { + if (relevancyInformation.relevantStates.get(predecessorEntry.getColumn())) { + for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(predecessorEntry.getColumn())) { bool choiceTargetsCurrentState = false; for (auto const& successorEntry : transitionMatrix.getRow(predecessorChoice)) { - if (successorEntry.first == currentState) { + if (successorEntry.getColumn() == currentState) { choiceTargetsCurrentState = true; } } @@ -554,8 +556,7 @@ namespace storm { * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. */ - static void assertSymbolicCuts(storm::ir::Program const& program, storm::models::Mdp const& labeledMdp, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { - + static void assertSymbolicCuts(storm::prism::Program const& program, storm::models::Mdp const& labeledMdp, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { // A container storing the label sets that may precede a given label set. std::map, std::set>> precedingLabelSets; @@ -581,11 +582,11 @@ namespace storm { // Iterate over predecessors and add all choices that target the current state to the preceding // label set of all labels of all relevant choices of the current state. for (auto const& predecessorEntry : backwardTransitions.getRow(currentState)) { - if (relevancyInformation.relevantStates.get(predecessorEntry.first)) { - for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(predecessorEntry.first)) { + if (relevancyInformation.relevantStates.get(predecessorEntry.getColumn())) { + for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(predecessorEntry.getColumn())) { bool choiceTargetsCurrentState = false; for (auto const& successorEntry : transitionMatrix.getRow(predecessorChoice)) { - if (successorEntry.first == currentState) { + if (successorEntry.getColumn() == currentState) { choiceTargetsCurrentState = true; } } @@ -599,45 +600,50 @@ namespace storm { } } - storm::utility::ir::VariableInformation programVariableInformation = storm::utility::ir::createVariableInformation(program); - // Create a context and register all variables of the program with their correct type. z3::context localContext; + z3::solver localSolver(localContext); std::map solverVariables; - for (auto const& booleanVariable : programVariableInformation.booleanVariables) { + for (auto const& booleanVariable : program.getGlobalBooleanVariables()) { solverVariables.emplace(booleanVariable.getName(), localContext.bool_const(booleanVariable.getName().c_str())); } - for (auto const& integerVariable : programVariableInformation.integerVariables) { + for (auto const& integerVariable : program.getGlobalIntegerVariables()) { solverVariables.emplace(integerVariable.getName(), localContext.int_const(integerVariable.getName().c_str())); } - // Now create a corresponding local solver and assert all range bounds for the integer variables. - z3::solver localSolver(localContext); + for (auto const& module : program.getModules()) { + for (auto const& booleanVariable : module.getBooleanVariables()) { + solverVariables.emplace(booleanVariable.getName(), localContext.bool_const(booleanVariable.getName().c_str())); + } + for (auto const& integerVariable : module.getIntegerVariables()) { + solverVariables.emplace(integerVariable.getName(), localContext.int_const(integerVariable.getName().c_str())); + } + } + storm::adapters::Z3ExpressionAdapter expressionAdapter(localContext, solverVariables); - for (auto const& integerVariable : programVariableInformation.integerVariables) { - z3::expr lowerBound = expressionAdapter.translateExpression(integerVariable.getLowerBound()); + + // Then add the constraints for bounds of the integer variables.. + for (auto const& integerVariable : program.getGlobalIntegerVariables()) { + z3::expr lowerBound = expressionAdapter.translateExpression(integerVariable.getLowerBoundExpression()); lowerBound = solverVariables.at(integerVariable.getName()) >= lowerBound; localSolver.add(lowerBound); - - z3::expr upperBound = expressionAdapter.translateExpression(integerVariable.getUpperBound()); + z3::expr upperBound = expressionAdapter.translateExpression(integerVariable.getUpperBoundExpression()); upperBound = solverVariables.at(integerVariable.getName()) <= upperBound; localSolver.add(upperBound); } - - // Construct an expression that exactly characterizes the initial state. - std::unique_ptr initialState(storm::utility::ir::getInitialState(program, programVariableInformation)); - z3::expr initialStateExpression = localContext.bool_val(true); - for (uint_fast64_t index = 0; index < programVariableInformation.booleanVariables.size(); ++index) { - if (std::get<0>(*initialState).at(programVariableInformation.booleanVariableToIndexMap.at(programVariableInformation.booleanVariables[index].getName()))) { - initialStateExpression = initialStateExpression && solverVariables.at(programVariableInformation.booleanVariables[index].getName()); - } else { - initialStateExpression = initialStateExpression && !solverVariables.at(programVariableInformation.booleanVariables[index].getName()); + for (auto const& module : program.getModules()) { + for (auto const& integerVariable : module.getIntegerVariables()) { + z3::expr lowerBound = expressionAdapter.translateExpression(integerVariable.getLowerBoundExpression()); + lowerBound = solverVariables.at(integerVariable.getName()) >= lowerBound; + localSolver.add(lowerBound); + z3::expr upperBound = expressionAdapter.translateExpression(integerVariable.getUpperBoundExpression()); + upperBound = solverVariables.at(integerVariable.getName()) <= upperBound; + localSolver.add(upperBound); } } - for (uint_fast64_t index = 0; index < programVariableInformation.integerVariables.size(); ++index) { - storm::ir::IntegerVariable const& variable = programVariableInformation.integerVariables[index]; - initialStateExpression = initialStateExpression && (solverVariables.at(variable.getName()) == localContext.int_val(static_cast(std::get<1>(*initialState).at(programVariableInformation.integerVariableToIndexMap.at(variable.getName()))))); - } + + // Construct an expression that exactly characterizes the initial state. + storm::expressions::Expression initialStateExpression = program.getInitialConstruct().getInitialStatesExpression(); // Store the found implications in a container similar to the preceding label sets. std::map, std::set>> backwardImplications; @@ -646,12 +652,12 @@ namespace storm { for (auto const& labelSetAndPrecedingLabelSetsPair : precedingLabelSets) { // Find out the commands for the currently considered label set. - std::vector> currentCommandVector; + std::vector> currentCommandVector; for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) { - storm::ir::Module const& module = program.getModule(moduleIndex); + storm::prism::Module const& module = program.getModule(moduleIndex); for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) { - storm::ir::Command const& command = module.getCommand(commandIndex); + storm::prism::Command const& command = module.getCommand(commandIndex); // If the current command is one of the commands we need to consider, store a reference to it in the container. if (labelSetAndPrecedingLabelSetsPair.first.find(command.getGlobalIndex()) != labelSetAndPrecedingLabelSetsPair.first.end()) { @@ -665,9 +671,9 @@ namespace storm { // Check if the command set is enabled in the initial state. for (auto const& command : currentCommandVector) { - localSolver.add(expressionAdapter.translateExpression(command.get().getGuard())); + localSolver.add(expressionAdapter.translateExpression(command.get().getGuardExpression())); } - localSolver.add(initialStateExpression); + localSolver.add(expressionAdapter.translateExpression(initialStateExpression)); z3::check_result checkResult = localSolver.check(); localSolver.pop(); @@ -676,19 +682,19 @@ namespace storm { // If the solver reports unsat, then we know that the current selection is not enabled in the initial state. if (checkResult == z3::unsat) { LOG4CPLUS_DEBUG(logger, "Selection not enabled in initial state."); - std::unique_ptr guardConjunction; + storm::expressions::Expression guardConjunction; if (currentCommandVector.size() == 1) { - guardConjunction = currentCommandVector.begin()->get().getGuard()->clone(); + guardConjunction = currentCommandVector.begin()->get().getGuardExpression(); } else if (currentCommandVector.size() > 1) { - std::vector>::const_iterator setIterator = currentCommandVector.begin(); - std::unique_ptr first = setIterator->get().getGuard()->clone(); + std::vector>::const_iterator setIterator = currentCommandVector.begin(); + storm::expressions::Expression first = setIterator->get().getGuardExpression(); ++setIterator; - std::unique_ptr second = setIterator->get().getGuard()->clone(); - guardConjunction = std::unique_ptr(new storm::ir::expressions::BinaryBooleanFunctionExpression(std::move(first), std::move(second), storm::ir::expressions::BinaryBooleanFunctionExpression::AND)); + storm::expressions::Expression second = setIterator->get().getGuardExpression(); + guardConjunction = first && second; ++setIterator; while (setIterator != currentCommandVector.end()) { - guardConjunction = std::unique_ptr(new storm::ir::expressions::BinaryBooleanFunctionExpression(std::move(guardConjunction), setIterator->get().getGuard()->clone(), storm::ir::expressions::BinaryBooleanFunctionExpression::AND)); + guardConjunction = guardConjunction && setIterator->get().getGuardExpression(); ++setIterator; } } else { @@ -701,9 +707,9 @@ namespace storm { bool firstAssignment = true; for (auto const& command : currentCommandVector) { if (firstAssignment) { - guardExpression = !expressionAdapter.translateExpression(command.get().getGuard()); + guardExpression = !expressionAdapter.translateExpression(command.get().getGuardExpression()); } else { - guardExpression = guardExpression | !expressionAdapter.translateExpression(command.get().getGuard()); + guardExpression = guardExpression | !expressionAdapter.translateExpression(command.get().getGuardExpression()); } } localSolver.add(guardExpression); @@ -715,12 +721,12 @@ namespace storm { localSolver.push(); // Find out the commands for the currently considered preceding label set. - std::vector> currentPrecedingCommandVector; + std::vector> currentPrecedingCommandVector; for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) { - storm::ir::Module const& module = program.getModule(moduleIndex); + storm::prism::Module const& module = program.getModule(moduleIndex); for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) { - storm::ir::Command const& command = module.getCommand(commandIndex); + storm::prism::Command const& command = module.getCommand(commandIndex); // If the current command is one of the commands we need to consider, store a reference to it in the container. if (precedingLabelSet.find(command.getGlobalIndex()) != precedingLabelSet.end()) { @@ -731,10 +737,10 @@ namespace storm { // Assert all the guards of the preceding command set. for (auto const& command : currentPrecedingCommandVector) { - localSolver.add(expressionAdapter.translateExpression(command.get().getGuard())); + localSolver.add(expressionAdapter.translateExpression(command.get().getGuardExpression())); } - std::vector::const_iterator> iteratorVector; + std::vector::const_iterator> iteratorVector; for (auto const& command : currentPrecedingCommandVector) { iteratorVector.push_back(command.get().getUpdates().begin()); } @@ -743,13 +749,15 @@ namespace storm { std::vector formulae; bool done = false; while (!done) { - std::vector> currentUpdateCombination; + std::map currentUpdateCombinationMap; for (auto const& updateIterator : iteratorVector) { - currentUpdateCombination.push_back(*updateIterator); + for (auto const& assignment : updateIterator->getAssignments()) { + currentUpdateCombinationMap.emplace(assignment.getVariableName(), assignment.getExpression()); + } } LOG4CPLUS_DEBUG(logger, "About to assert a weakest precondition."); - std::unique_ptr wp = storm::utility::ir::getWeakestPrecondition(guardConjunction->clone(), currentUpdateCombination); + storm::expressions::Expression wp = guardConjunction.substitute(currentUpdateCombinationMap); formulae.push_back(expressionAdapter.translateExpression(wp)); LOG4CPLUS_DEBUG(logger, "Asserted weakest precondition."); @@ -913,16 +921,16 @@ namespace storm { // Assert the constraints (1). boost::container::flat_set relevantPredecessors; for (auto const& predecessorEntry : backwardTransitions.getRow(relevantState)) { - if (relevantState != predecessorEntry.first && relevancyInformation.relevantStates.get(predecessorEntry.first)) { - relevantPredecessors.insert(predecessorEntry.first); + if (relevantState != predecessorEntry.getColumn() && relevancyInformation.relevantStates.get(predecessorEntry.getColumn())) { + relevantPredecessors.insert(predecessorEntry.getColumn()); } } boost::container::flat_set relevantSuccessors; for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(relevantState)) { for (auto const& successorEntry : transitionMatrix.getRow(relevantChoice)) { - if (relevantState != successorEntry.first && (relevancyInformation.relevantStates.get(successorEntry.first) || psiStates.get(successorEntry.first))) { - relevantSuccessors.insert(successorEntry.first); + if (relevantState != successorEntry.getColumn() && (relevancyInformation.relevantStates.get(successorEntry.getColumn()) || psiStates.get(successorEntry.getColumn()))) { + relevantSuccessors.insert(successorEntry.getColumn()); } } } @@ -941,8 +949,8 @@ namespace storm { boost::container::flat_set relevantSuccessors; for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(relevantState)) { for (auto const& successorEntry : transitionMatrix.getRow(relevantChoice)) { - if (relevantState != successorEntry.first && (relevancyInformation.relevantStates.get(successorEntry.first) || psiStates.get(successorEntry.first))) { - relevantSuccessors.insert(successorEntry.first); + if (relevantState != successorEntry.getColumn() && (relevancyInformation.relevantStates.get(successorEntry.getColumn()) || psiStates.get(successorEntry.getColumn()))) { + relevantSuccessors.insert(successorEntry.getColumn()); } } } @@ -965,7 +973,7 @@ namespace storm { boost::container::flat_set choicesForStatePair; for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(sourceState)) { for (auto const& successorEntry : transitionMatrix.getRow(relevantChoice)) { - if (successorEntry.first == targetState) { + if (successorEntry.getColumn() == targetState) { choicesForStatePair.insert(relevantChoice); } } @@ -1400,13 +1408,13 @@ namespace storm { bool choiceTargetsRelevantState = false; for (auto const& successorEntry : transitionMatrix.getRow(currentChoice)) { - if (relevancyInformation.relevantStates.get(successorEntry.first) && currentState != successorEntry.first) { + if (relevancyInformation.relevantStates.get(successorEntry.getColumn()) && currentState != successorEntry.getColumn()) { choiceTargetsRelevantState = true; - if (!reachableStates.get(successorEntry.first)) { - reachableStates.set(successorEntry.first); - stack.push_back(successorEntry.first); + if (!reachableStates.get(successorEntry.getColumn())) { + reachableStates.set(successorEntry.getColumn()); + stack.push_back(successorEntry.getColumn()); } - } else if (psiStates.get(successorEntry.first)) { + } else if (psiStates.get(successorEntry.getColumn())) { targetStateIsReachable = true; } } @@ -1446,7 +1454,7 @@ namespace storm { // Determine whether the state has the option to leave the reachable state space and go to the unreachable relevant states. for (auto const& successorEntry : originalMdp.getTransitionMatrix().getRow(currentChoice)) { - if (unreachableRelevantStates.get(successorEntry.first)) { + if (unreachableRelevantStates.get(successorEntry.getColumn())) { isBorderChoice = true; } } @@ -1529,13 +1537,13 @@ namespace storm { bool choiceTargetsRelevantState = false; for (auto const& successorEntry : transitionMatrix.getRow(currentChoice)) { - if (relevancyInformation.relevantStates.get(successorEntry.first) && currentState != successorEntry.first) { + if (relevancyInformation.relevantStates.get(successorEntry.getColumn()) && currentState != successorEntry.getColumn()) { choiceTargetsRelevantState = true; - if (!reachableStates.get(successorEntry.first)) { - reachableStates.set(successorEntry.first, true); - stack.push_back(successorEntry.first); + if (!reachableStates.get(successorEntry.getColumn())) { + reachableStates.set(successorEntry.getColumn(), true); + stack.push_back(successorEntry.getColumn()); } - } else if (psiStates.get(successorEntry.first)) { + } else if (psiStates.get(successorEntry.getColumn())) { targetStateIsReachable = true; } } @@ -1613,7 +1621,7 @@ namespace storm { * @param checkThresholdFeasible If set, it is verified that the model can actually achieve/exceed the given probability value. If this check * is made and fails, an exception is thrown. */ - static boost::container::flat_set getMinimalCommandSet(storm::ir::Program program, std::string const& constantDefinitionString, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool strictBound, bool checkThresholdFeasible = false, bool includeReachabilityEncoding = false) { + static boost::container::flat_set getMinimalCommandSet(storm::prism::Program program, std::string const& constantDefinitionString, storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool strictBound, bool checkThresholdFeasible = false, bool includeReachabilityEncoding = false) { #ifdef STORM_HAVE_Z3 // Set up all clocks used for time measurement. auto totalClock = std::chrono::high_resolution_clock::now(); @@ -1632,7 +1640,9 @@ namespace storm { auto analysisClock = std::chrono::high_resolution_clock::now(); decltype(std::chrono::high_resolution_clock::now() - analysisClock) totalAnalysisTime(0); - storm::utility::ir::defineUndefinedConstants(program, constantDefinitionString); + std::map constantDefinitions = storm::utility::prism::parseConstantDefinitionString(program, constantDefinitionString); + storm::prism::Program preparedProgram = program.defineUndefinedConstants(constantDefinitions); + preparedProgram = preparedProgram.substituteConstants(); // (0) Check whether the MDP is indeed labeled. if (!labeledMdp.hasChoiceLabeling()) { @@ -1676,7 +1686,7 @@ namespace storm { LOG4CPLUS_DEBUG(logger, "Asserting cuts."); assertExplicitCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver); LOG4CPLUS_DEBUG(logger, "Asserted explicit cuts."); - assertSymbolicCuts(program, labeledMdp, variableInformation, relevancyInformation, context, solver); + assertSymbolicCuts(preparedProgram, labeledMdp, variableInformation, relevancyInformation, context, solver); LOG4CPLUS_DEBUG(logger, "Asserted symbolic cuts."); if (includeReachabilityEncoding) { assertReachabilityCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver); @@ -1764,8 +1774,6 @@ namespace storm { std::cout << " * number of models that could not reach a target state: " << zeroProbabilityCount << " (" << 100 * static_cast(zeroProbabilityCount)/iterations << "%)" << std::endl << std::endl; } - // (9) Return the resulting command set after undefining the constants. - storm::utility::ir::undefineUndefinedConstants(program); return commandSet; #else @@ -1773,7 +1781,7 @@ namespace storm { #endif } - static void computeCounterexample(storm::ir::Program program, std::string const& constantDefinitionString, storm::models::Mdp const& labeledMdp, storm::property::prctl::AbstractPrctlFormula const* formulaPtr) { + static void computeCounterexample(storm::prism::Program program, std::string const& constantDefinitionString, storm::models::Mdp const& labeledMdp, storm::property::prctl::AbstractPrctlFormula const* formulaPtr) { #ifdef STORM_HAVE_Z3 std::cout << std::endl << "Generating minimal label counterexample for formula " << formulaPtr->toString() << std::endl; // First, we need to check whether the current formula is an Until-Formula. @@ -1821,9 +1829,8 @@ namespace storm { std::cout << std::endl << "Computed minimal label set of size " << labelSet.size() << " in " << std::chrono::duration_cast(endTime - startTime).count() << "ms." << std::endl; std::cout << "Resulting program:" << std::endl << std::endl; - storm::ir::Program restrictedProgram(program); - restrictedProgram.restrictCommands(labelSet); - std::cout << restrictedProgram.toString() << std::endl; + storm::prism::Program restrictedProgram = program.restrictCommands(labelSet); + std::cout << restrictedProgram << std::endl; std::cout << std::endl << "-------------------------------------------" << std::endl; #else diff --git a/src/exceptions/ExceptionMacros.h b/src/exceptions/ExceptionMacros.h index 09e7058ec..3840928fe 100644 --- a/src/exceptions/ExceptionMacros.h +++ b/src/exceptions/ExceptionMacros.h @@ -22,10 +22,10 @@ extern log4cplus::Logger logger; #define LOG_THROW(cond, exception, message) \ { \ -if (!(cond)) { \ -LOG4CPLUS_ERROR(logger, message); \ -throw exception() << message; \ -} \ +if (!(cond)) { \ +LOG4CPLUS_ERROR(logger, message); \ +throw exception() << message; \ +} \ } while (false) #endif /* STORM_EXCEPTIONS_EXCEPTIONMACROS_H_ */ \ No newline at end of file diff --git a/src/modelchecker/csl/SparseMarkovAutomatonCslModelChecker.h b/src/modelchecker/csl/SparseMarkovAutomatonCslModelChecker.h index dbcf4d8b1..e634ad597 100644 --- a/src/modelchecker/csl/SparseMarkovAutomatonCslModelChecker.h +++ b/src/modelchecker/csl/SparseMarkovAutomatonCslModelChecker.h @@ -302,8 +302,8 @@ namespace storm { // Finally, we are ready to create the SSP matrix and right-hand side of the SSP. std::vector b; - typename storm::storage::SparseMatrixBuilder sspMatrixBuilder(0, 0, 0, true, numberOfStatesNotInMecs + mecDecomposition.size() + 1); - + typename storm::storage::SparseMatrixBuilder sspMatrixBuilder(0, 0, 0, true, numberOfStatesNotInMecs + mecDecomposition.size()); + // If the source state is not contained in any MEC, we copy its choices (and perform the necessary modifications). uint_fast64_t currentChoice = 0; for (auto state : statesNotContainedInAnyMec) { @@ -385,7 +385,7 @@ namespace storm { } // Finalize the matrix and solve the corresponding system of equations. - storm::storage::SparseMatrix sspMatrix = sspMatrixBuilder.build(currentChoice + 1); + storm::storage::SparseMatrix sspMatrix = sspMatrixBuilder.build(currentChoice); std::vector x(numberOfStatesNotInMecs + mecDecomposition.size()); nondeterministicLinearEquationSolver->solveEquationSystem(min, sspMatrix, x, b); diff --git a/src/models/MarkovAutomaton.h b/src/models/MarkovAutomaton.h index 89ce7ccee..387590df6 100644 --- a/src/models/MarkovAutomaton.h +++ b/src/models/MarkovAutomaton.h @@ -126,7 +126,7 @@ namespace storm { //uint_fast64_t newNumberOfRows = this->getNumberOfChoices() - numberOfHybridStates; // Create the matrix for the new transition relation and the corresponding nondeterministic choice vector. - storm::storage::SparseMatrixBuilder newTransitionMatrixBuilder(0, 0, 0, true, this->getNumberOfStates() + 1); + storm::storage::SparseMatrixBuilder newTransitionMatrixBuilder(0, 0, 0, true, this->getNumberOfStates()); // Now copy over all choices that need to be kept. uint_fast64_t currentChoice = 0; diff --git a/src/models/Mdp.h b/src/models/Mdp.h index 61858fefa..de5e3d78b 100644 --- a/src/models/Mdp.h +++ b/src/models/Mdp.h @@ -139,7 +139,7 @@ public: std::vector> const& choiceLabeling = this->getChoiceLabeling(); - storm::storage::SparseMatrixBuilder transitionMatrixBuilder; + storm::storage::SparseMatrixBuilder transitionMatrixBuilder(0, 0, 0, true); std::vector> newChoiceLabeling; // Check for each choice of each state, whether the choice labels are fully contained in the given label set. @@ -156,7 +156,7 @@ public: } stateHasValidChoice = true; for (auto const& entry : this->getTransitionMatrix().getRow(choice)) { - transitionMatrixBuilder.addNextValue(currentRow, entry.first, entry.second); + transitionMatrixBuilder.addNextValue(currentRow, entry.getColumn(), entry.getValue()); } newChoiceLabeling.emplace_back(choiceLabeling[choice]); ++currentRow; diff --git a/src/parser/ExpressionParser.cpp b/src/parser/ExpressionParser.cpp index 6459d7a9d..ed6d1d642 100644 --- a/src/parser/ExpressionParser.cpp +++ b/src/parser/ExpressionParser.cpp @@ -5,50 +5,50 @@ namespace storm { namespace parser { - ExpressionParser::ExpressionParser(qi::symbols const& invalidIdentifiers_) : ExpressionParser::base_type(expression), createExpressions(false), acceptDoubleLiterals(true), identifiers_(nullptr), invalidIdentifiers_(invalidIdentifiers_) { + ExpressionParser::ExpressionParser(qi::symbols const& invalidIdentifiers_) : ExpressionParser::base_type(expression), orOperator_(), andOperator_(), equalityOperator_(), relationalOperator_(), plusOperator_(), multiplicationOperator_(), powerOperator_(), unaryOperator_(), floorCeilOperator_(), minMaxOperator_(), trueFalse_(), createExpressions(false), acceptDoubleLiterals(true), identifiers_(nullptr), invalidIdentifiers_(invalidIdentifiers_) { identifier %= qi::as_string[qi::raw[qi::lexeme[((qi::alpha | qi::char_('_')) >> *(qi::alnum | qi::char_('_')))]]][qi::_pass = phoenix::bind(&ExpressionParser::isValidIdentifier, phoenix::ref(*this), qi::_1)]; identifier.name("identifier"); - floorCeilExpression = ((qi::lit("floor")[qi::_a = true] | qi::lit("ceil")[qi::_a = false]) >> qi::lit("(") >> plusExpression >> qi::lit(")"))[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createFloorExpression, phoenix::ref(*this), qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createCeilExpression, phoenix::ref(*this), qi::_1)]]; + floorCeilExpression = ((floorCeilOperator_ >> qi::lit("(")) > plusExpression > qi::lit(")"))[qi::_val = phoenix::bind(&ExpressionParser::createFloorCeilExpression, phoenix::ref(*this), qi::_1, qi::_2)]; floorCeilExpression.name("floor/ceil expression"); - minMaxExpression = ((qi::lit("min")[qi::_a = true] | qi::lit("max")[qi::_a = false]) >> qi::lit("(") >> plusExpression >> qi::lit(",") >> plusExpression >> qi::lit(")"))[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createMinimumExpression, phoenix::ref(*this), qi::_1, qi::_2)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createMaximumExpression, phoenix::ref(*this), qi::_1, qi::_2)]]; + minMaxExpression = ((minMaxOperator_ >> qi::lit("(")) > plusExpression > qi::lit(",") > plusExpression > qi::lit(")"))[qi::_val = phoenix::bind(&ExpressionParser::createMinimumMaximumExpression, phoenix::ref(*this), qi::_2, qi::_1, qi::_3)]; minMaxExpression.name("min/max expression"); identifierExpression = identifier[qi::_val = phoenix::bind(&ExpressionParser::getIdentifierExpression, phoenix::ref(*this), qi::_1)]; identifierExpression.name("identifier expression"); - literalExpression = qi::lit("true")[qi::_val = phoenix::bind(&ExpressionParser::createTrueExpression, phoenix::ref(*this))] | qi::lit("false")[qi::_val = phoenix::bind(&ExpressionParser::createFalseExpression, phoenix::ref(*this))] | strict_double[qi::_val = phoenix::bind(&ExpressionParser::createDoubleLiteralExpression, phoenix::ref(*this), qi::_1, qi::_pass)] | qi::int_[qi::_val = phoenix::bind(&ExpressionParser::createIntegerLiteralExpression, phoenix::ref(*this), qi::_1)]; + literalExpression = trueFalse_[qi::_val = qi::_1] | strict_double[qi::_val = phoenix::bind(&ExpressionParser::createDoubleLiteralExpression, phoenix::ref(*this), qi::_1, qi::_pass)] | qi::int_[qi::_val = phoenix::bind(&ExpressionParser::createIntegerLiteralExpression, phoenix::ref(*this), qi::_1)]; literalExpression.name("literal expression"); - atomicExpression = minMaxExpression | floorCeilExpression | qi::lit("(") >> expression >> qi::lit(")") | literalExpression | identifierExpression; + atomicExpression = floorCeilExpression | minMaxExpression | (qi::lit("(") >> expression >> qi::lit(")")) | literalExpression | identifierExpression; atomicExpression.name("atomic expression"); - unaryExpression = atomicExpression[qi::_val = qi::_1] | (qi::lit("!") >> atomicExpression)[qi::_val = phoenix::bind(&ExpressionParser::createNotExpression, phoenix::ref(*this), qi::_1)] | (qi::lit("-") >> atomicExpression)[qi::_val = phoenix::bind(&ExpressionParser::createMinusExpression, phoenix::ref(*this), qi::_1)]; + unaryExpression = (-unaryOperator_ >> atomicExpression)[qi::_val = phoenix::bind(&ExpressionParser::createUnaryExpression, phoenix::ref(*this), qi::_1, qi::_2)]; unaryExpression.name("unary expression"); - powerExpression = unaryExpression[qi::_val = qi::_1] >> -(qi::lit("^") > expression)[qi::_val = phoenix::bind(&ExpressionParser::createPowerExpression, phoenix::ref(*this), qi::_val, qi::_1)]; + powerExpression = unaryExpression[qi::_val = qi::_1] > -(powerOperator_ > expression)[qi::_val = phoenix::bind(&ExpressionParser::createPowerExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; powerExpression.name("power expression"); - multiplicationExpression = powerExpression[qi::_val = qi::_1] >> *((qi::lit("*")[qi::_a = true] | qi::lit("/")[qi::_a = false]) >> powerExpression[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createMultExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createDivExpression, phoenix::ref(*this), qi::_val, qi::_1)]]); + multiplicationExpression = powerExpression[qi::_val = qi::_1] > *(multiplicationOperator_ > powerExpression)[qi::_val = phoenix::bind(&ExpressionParser::createMultExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; multiplicationExpression.name("multiplication expression"); - plusExpression = multiplicationExpression[qi::_val = qi::_1] >> *((qi::lit("+")[qi::_a = true] | qi::lit("-")[qi::_a = false]) >> multiplicationExpression)[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createPlusExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createMinusExpression, phoenix::ref(*this), qi::_val, qi::_1)]]; + plusExpression = multiplicationExpression[qi::_val = qi::_1] > *(plusOperator_ >> multiplicationExpression)[qi::_val = phoenix::bind(&ExpressionParser::createPlusExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; plusExpression.name("plus expression"); - relativeExpression = (plusExpression >> qi::lit(">=") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createGreaterOrEqualExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit(">") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createGreaterExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit("<=") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createLessOrEqualExpression, phoenix::ref(*this), qi::_1, qi::_2)] | (plusExpression >> qi::lit("<") >> plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createLessExpression, phoenix::ref(*this), qi::_1, qi::_2)] | plusExpression[qi::_val = qi::_1]; + relativeExpression = plusExpression[qi::_val = qi::_1] > -(relationalOperator_ > plusExpression)[qi::_val = phoenix::bind(&ExpressionParser::createRelationalExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; relativeExpression.name("relative expression"); - - equalityExpression = relativeExpression[qi::_val = qi::_1] >> *((qi::lit("=")[qi::_a = true] | qi::lit("!=")[qi::_a = false]) >> relativeExpression)[phoenix::if_(qi::_a) [ qi::_val = phoenix::bind(&ExpressionParser::createEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1) ] .else_ [ qi::_val = phoenix::bind(&ExpressionParser::createNotEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1) ] ]; + + equalityExpression = relativeExpression[qi::_val = qi::_1] >> *(equalityOperator_ >> relativeExpression)[qi::_val = phoenix::bind(&ExpressionParser::createEqualsExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; equalityExpression.name("equality expression"); - andExpression = equalityExpression[qi::_val = qi::_1] >> *(qi::lit("&") >> equalityExpression)[qi::_val = phoenix::bind(&ExpressionParser::createAndExpression, phoenix::ref(*this), qi::_val, qi::_1)]; + andExpression = equalityExpression[qi::_val = qi::_1] >> *(andOperator_ > equalityExpression)[qi::_val = phoenix::bind(&ExpressionParser::createAndExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; andExpression.name("and expression"); - orExpression = andExpression[qi::_val = qi::_1] >> *((qi::lit("|")[qi::_a = true] | qi::lit("=>")[qi::_a = false]) >> andExpression)[phoenix::if_(qi::_a) [qi::_val = phoenix::bind(&ExpressionParser::createOrExpression, phoenix::ref(*this), qi::_val, qi::_1)] .else_ [qi::_val = phoenix::bind(&ExpressionParser::createImpliesExpression, phoenix::ref(*this), qi::_val, qi::_1)] ]; + orExpression = andExpression[qi::_val = qi::_1] > *(orOperator_ > andExpression)[qi::_val = phoenix::bind(&ExpressionParser::createOrExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; orExpression.name("or expression"); - iteExpression = orExpression[qi::_val = qi::_1] >> -(qi::lit("?") > orExpression > qi::lit(":") > orExpression)[qi::_val = phoenix::bind(&ExpressionParser::createIteExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; + iteExpression = orExpression[qi::_val = qi::_1] > -(qi::lit("?") > orExpression > qi::lit(":") > orExpression)[qi::_val = phoenix::bind(&ExpressionParser::createIteExpression, phoenix::ref(*this), qi::_val, qi::_1, qi::_2)]; iteExpression.name("if-then-else expression"); expression %= iteExpression; @@ -97,219 +97,134 @@ namespace storm { } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createImpliesExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1.implies(e2); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createOrExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1 || e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createAndExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try{ - return e1 && e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createGreaterExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1 > e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createGreaterOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1 >= e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createLessExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1 < e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createLessOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1 <= e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - - storm::expressions::Expression ExpressionParser::createEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + + storm::expressions::Expression ExpressionParser::createOrExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - if (e1.hasBooleanReturnType() && e2.hasBooleanReturnType()) { - return e1.iff(e2); - } else { - return e1 == e2; + switch (operatorType) { + case storm::expressions::OperatorType::Or: return e1 || e2; break; + case storm::expressions::OperatorType::Implies: return e1.implies(e2); break; + default: LOG_ASSERT(false, "Invalid operation."); break; } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createNotEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return e1 != e2; - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - storm::expressions::Expression ExpressionParser::createPlusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + storm::expressions::Expression ExpressionParser::createAndExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return e1 + e2; + switch (operatorType) { + case storm::expressions::OperatorType::And: return e1 && e2; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - storm::expressions::Expression ExpressionParser::createMinusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + storm::expressions::Expression ExpressionParser::createRelationalExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return e1 - e2; + switch (operatorType) { + case storm::expressions::OperatorType::GreaterOrEqual: return e1 >= e2; break; + case storm::expressions::OperatorType::Greater: return e1 > e2; break; + case storm::expressions::OperatorType::LessOrEqual: return e1 <= e2; break; + case storm::expressions::OperatorType::Less: return e1 < e2; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - storm::expressions::Expression ExpressionParser::createMultExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + storm::expressions::Expression ExpressionParser::createEqualsExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return e1 * e2; + switch (operatorType) { + case storm::expressions::OperatorType::Equal: return e1.hasBooleanReturnType() && e2.hasBooleanReturnType() ? e1.iff(e2) : e1 == e2; break; + case storm::expressions::OperatorType::NotEqual: return e1 != e2; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - storm::expressions::Expression ExpressionParser::createPowerExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + storm::expressions::Expression ExpressionParser::createPlusExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return e1 ^ e2; + switch (operatorType) { + case storm::expressions::OperatorType::Plus: return e1 + e2; break; + case storm::expressions::OperatorType::Minus: return e1 - e2; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - storm::expressions::Expression ExpressionParser::createDivExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + storm::expressions::Expression ExpressionParser::createMultExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return e1 / e2; + switch (operatorType) { + case storm::expressions::OperatorType::Times: return e1 * e2; break; + case storm::expressions::OperatorType::Divide: return e1 / e2; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - storm::expressions::Expression ExpressionParser::createNotExpression(storm::expressions::Expression e1) const { + storm::expressions::Expression ExpressionParser::createPowerExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return !e1; + switch (operatorType) { + case storm::expressions::OperatorType::Power: return e1 ^ e2; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - - storm::expressions::Expression ExpressionParser::createMinusExpression(storm::expressions::Expression e1) const { + + storm::expressions::Expression ExpressionParser::createUnaryExpression(boost::optional const& operatorType, storm::expressions::Expression const& e1) const { if (this->createExpressions) { try { - return -e1; + if (operatorType) { + switch (operatorType.get()) { + case storm::expressions::OperatorType::Not: return !e1; break; + case storm::expressions::OperatorType::Minus: return -e1; break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } + } else { + return e1; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } - } - - storm::expressions::Expression ExpressionParser::createTrueExpression() const { - if (this->createExpressions) { - return storm::expressions::Expression::createTrue(); - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createFalseExpression() const { return storm::expressions::Expression::createFalse(); } - + storm::expressions::Expression ExpressionParser::createDoubleLiteralExpression(double value, bool& pass) const { // If we are not supposed to accept double expressions, we reject it by setting pass to false. if (!this->acceptDoubleLiterals) { @@ -331,52 +246,34 @@ namespace storm { } } - storm::expressions::Expression ExpressionParser::createMinimumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { - if (this->createExpressions) { - try { - return storm::expressions::Expression::minimum(e1, e2); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createMaximumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const { + storm::expressions::Expression ExpressionParser::createMinimumMaximumExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const { if (this->createExpressions) { try { - return storm::expressions::Expression::maximum(e1, e2); - } catch (storm::exceptions::InvalidTypeException const& e) { - LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); - } - } else { - return storm::expressions::Expression::createFalse(); - } - } - - storm::expressions::Expression ExpressionParser::createFloorExpression(storm::expressions::Expression e1) const { - if (this->createExpressions) { - try { - return e1.floor(); + switch (operatorType) { + case storm::expressions::OperatorType::Min: return storm::expressions::Expression::minimum(e1, e2); break; + case storm::expressions::OperatorType::Max: return storm::expressions::Expression::maximum(e1, e2); break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } - - storm::expressions::Expression ExpressionParser::createCeilExpression(storm::expressions::Expression e1) const { + + storm::expressions::Expression ExpressionParser::createFloorCeilExpression(storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e1) const { if (this->createExpressions) { try { - return e1.ceil(); + switch (operatorType) { + case storm::expressions::OperatorType::Floor: return e1.floor(); break; + case storm::expressions::OperatorType::Ceil: return e1.ceil(); break; + default: LOG_ASSERT(false, "Invalid operation."); break; + } } catch (storm::exceptions::InvalidTypeException const& e) { LOG_THROW(false, storm::exceptions::WrongFormatException, "Parsing error in line " << get_line(qi::_3) << ": " << e.what()); } - } else { - return storm::expressions::Expression::createFalse(); } + return storm::expressions::Expression::createFalse(); } storm::expressions::Expression ExpressionParser::getIdentifierExpression(std::string const& identifier) const { diff --git a/src/parser/ExpressionParser.h b/src/parser/ExpressionParser.h index 380644caf..423227aa2 100644 --- a/src/parser/ExpressionParser.h +++ b/src/parser/ExpressionParser.h @@ -43,6 +43,127 @@ namespace storm { void setAcceptDoubleLiterals(bool flag); private: + struct orOperatorStruct : qi::symbols { + orOperatorStruct() { + add + ("|", storm::expressions::OperatorType::Or) + ("=>", storm::expressions::OperatorType::Implies); + } + }; + + // A parser used for recognizing the operators at the "or" precedence level. + orOperatorStruct orOperator_; + + struct andOperatorStruct : qi::symbols { + andOperatorStruct() { + add + ("&", storm::expressions::OperatorType::And); + } + }; + + // A parser used for recognizing the operators at the "and" precedence level. + andOperatorStruct andOperator_; + + struct equalityOperatorStruct : qi::symbols { + equalityOperatorStruct() { + add + ("=", storm::expressions::OperatorType::Equal) + ("!=", storm::expressions::OperatorType::NotEqual); + } + }; + + // A parser used for recognizing the operators at the "equality" precedence level. + equalityOperatorStruct equalityOperator_; + + struct relationalOperatorStruct : qi::symbols { + relationalOperatorStruct() { + add + (">=", storm::expressions::OperatorType::GreaterOrEqual) + (">", storm::expressions::OperatorType::Greater) + ("<=", storm::expressions::OperatorType::LessOrEqual) + ("<", storm::expressions::OperatorType::Less); + } + }; + + // A parser used for recognizing the operators at the "relational" precedence level. + relationalOperatorStruct relationalOperator_; + + struct plusOperatorStruct : qi::symbols { + plusOperatorStruct() { + add + ("+", storm::expressions::OperatorType::Plus) + ("-", storm::expressions::OperatorType::Minus); + } + }; + + // A parser used for recognizing the operators at the "plus" precedence level. + plusOperatorStruct plusOperator_; + + struct multiplicationOperatorStruct : qi::symbols { + multiplicationOperatorStruct() { + add + ("*", storm::expressions::OperatorType::Times) + ("/", storm::expressions::OperatorType::Divide); + } + }; + + // A parser used for recognizing the operators at the "multiplication" precedence level. + multiplicationOperatorStruct multiplicationOperator_; + + struct powerOperatorStruct : qi::symbols { + powerOperatorStruct() { + add + ("^", storm::expressions::OperatorType::Power); + } + }; + + // A parser used for recognizing the operators at the "power" precedence level. + powerOperatorStruct powerOperator_; + + struct unaryOperatorStruct : qi::symbols { + unaryOperatorStruct() { + add + ("!", storm::expressions::OperatorType::Not) + ("-", storm::expressions::OperatorType::Minus); + } + }; + + // A parser used for recognizing the operators at the "unary" precedence level. + unaryOperatorStruct unaryOperator_; + + struct floorCeilOperatorStruct : qi::symbols { + floorCeilOperatorStruct() { + add + ("floor", storm::expressions::OperatorType::Floor) + ("ceil", storm::expressions::OperatorType::Ceil); + } + }; + + // A parser used for recognizing the operators at the "floor/ceil" precedence level. + floorCeilOperatorStruct floorCeilOperator_; + + struct minMaxOperatorStruct : qi::symbols { + minMaxOperatorStruct() { + add + ("min", storm::expressions::OperatorType::Min) + ("max", storm::expressions::OperatorType::Max); + } + }; + + // A parser used for recognizing the operators at the "min/max" precedence level. + minMaxOperatorStruct minMaxOperator_; + + struct trueFalseOperatorStruct : qi::symbols { + trueFalseOperatorStruct() { + add + ("true", storm::expressions::Expression::createTrue()) + ("false", storm::expressions::Expression::createFalse()); + } + }; + + // A parser used for recognizing the literals true and false. + trueFalseOperatorStruct trueFalse_; + // A flag that indicates whether expressions should actually be generated or just a syntax check shall be // performed. bool createExpressions; @@ -80,30 +201,18 @@ namespace storm { // Helper functions to create expressions. storm::expressions::Expression createIteExpression(storm::expressions::Expression e1, storm::expressions::Expression e2, storm::expressions::Expression e3) const; - storm::expressions::Expression createImpliesExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createOrExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createAndExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createGreaterExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createGreaterOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createLessExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createLessOrEqualExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createNotEqualsExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createPlusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createMinusExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createMultExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createPowerExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createDivExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createNotExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression createMinusExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression createTrueExpression() const; - storm::expressions::Expression createFalseExpression() const; + storm::expressions::Expression createOrExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createAndExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createRelationalExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createEqualsExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createPlusExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createMultExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createPowerExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createUnaryExpression(boost::optional const& operatorType, storm::expressions::Expression const& e1) const; storm::expressions::Expression createDoubleLiteralExpression(double value, bool& pass) const; storm::expressions::Expression createIntegerLiteralExpression(int value) const; - storm::expressions::Expression createMinimumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createMaximumExpression(storm::expressions::Expression e1, storm::expressions::Expression e2) const; - storm::expressions::Expression createFloorExpression(storm::expressions::Expression e1) const; - storm::expressions::Expression createCeilExpression(storm::expressions::Expression e1) const; + storm::expressions::Expression createMinimumMaximumExpression(storm::expressions::Expression const& e1, storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e2) const; + storm::expressions::Expression createFloorCeilExpression(storm::expressions::OperatorType const& operatorType, storm::expressions::Expression const& e1) const; storm::expressions::Expression getIdentifierExpression(std::string const& identifier) const; bool isValidIdentifier(std::string const& identifier); diff --git a/src/parser/MarkovAutomatonSparseTransitionParser.cpp b/src/parser/MarkovAutomatonSparseTransitionParser.cpp index 21403c3be..6e02df99d 100644 --- a/src/parser/MarkovAutomatonSparseTransitionParser.cpp +++ b/src/parser/MarkovAutomatonSparseTransitionParser.cpp @@ -257,9 +257,6 @@ namespace storm { ++currentChoice; } - // Put a sentinel element at the end. - result.transitionMatrixBuilder.newRowGroup(currentChoice); - return result; } diff --git a/src/parser/PrismParser.cpp b/src/parser/PrismParser.cpp index b026880d3..5d73e9d2a 100644 --- a/src/parser/PrismParser.cpp +++ b/src/parser/PrismParser.cpp @@ -265,10 +265,12 @@ namespace storm { return storm::prism::Constant(storm::expressions::ExpressionReturnType::Double, newConstant, expression, this->getFilename()); } - storm::prism::Formula PrismParser::createFormula(std::string const& formulaName, storm::expressions::Expression expression) const { + storm::prism::Formula PrismParser::createFormula(std::string const& formulaName, storm::expressions::Expression expression) { if (!this->secondRun) { LOG_THROW(this->identifiers_.find(formulaName) == nullptr, storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Duplicate identifier '" << formulaName << "'."); this->identifiers_.add(formulaName, expression); + } else { + this->identifiers_.at(formulaName) = expression; } return storm::prism::Formula(formulaName, expression, this->getFilename()); } diff --git a/src/parser/PrismParser.h b/src/parser/PrismParser.h index bdf84af24..ac36883e9 100644 --- a/src/parser/PrismParser.h +++ b/src/parser/PrismParser.h @@ -225,7 +225,7 @@ namespace storm { storm::prism::Constant createDefinedBooleanConstant(std::string const& newConstant, storm::expressions::Expression expression) const; storm::prism::Constant createDefinedIntegerConstant(std::string const& newConstant, storm::expressions::Expression expression) const; storm::prism::Constant createDefinedDoubleConstant(std::string const& newConstant, storm::expressions::Expression expression) const; - storm::prism::Formula createFormula(std::string const& formulaName, storm::expressions::Expression expression) const; + storm::prism::Formula createFormula(std::string const& formulaName, storm::expressions::Expression expression); storm::prism::Label createLabel(std::string const& labelName, storm::expressions::Expression expression) const; storm::prism::RewardModel createRewardModel(std::string const& rewardModelName, std::vector const& stateRewards, std::vector const& transitionRewards) const; storm::prism::StateReward createStateReward(storm::expressions::Expression statePredicateExpression, storm::expressions::Expression rewardValueExpression) const; diff --git a/src/solver/SmtSolver.h b/src/solver/SmtSolver.h new file mode 100644 index 000000000..c34c8a440 --- /dev/null +++ b/src/solver/SmtSolver.h @@ -0,0 +1,190 @@ +#ifndef STORM_SOLVER_SMTSOLVER +#define STORM_SOLVER_SMTSOLVER + +#include + +#include "exceptions/IllegalArgumentValueException.h" +#include "exceptions/NotImplementedException.h" +#include "exceptions/IllegalArgumentTypeException.h" +#include "exceptions/IllegalFunctionCallException.h" +#include "exceptions/InvalidStateException.h" +#include "storage/expressions/Expressions.h" +#include "storage/expressions/SimpleValuation.h" + +#include +#include +#include +#include + +namespace storm { + namespace solver { + + /*! + * An interface that captures the functionality of an SMT solver. + */ + class SmtSolver { + public: + //! Option flags for smt solvers. + enum class Options { + ModelGeneration = 0x01, + UnsatCoreComputation = 0x02, + InterpolantComputation = 0x04 + }; + //! possible check results + enum class CheckResult { SAT, UNSAT, UNKNOWN }; + public: + /*! + * Constructs a new smt solver with the given options. + * + * @param options the options for the solver + * @throws storm::exceptions::IllegalArgumentValueException if an option is unsupported for the solver + */ + SmtSolver(Options options = Options::ModelGeneration) {}; + virtual ~SmtSolver() {}; + + SmtSolver(const SmtSolver&) = delete; + SmtSolver(const SmtSolver&&) {}; + + //! pushes a backtrackingpoint in the solver + virtual void push() = 0; + //! pops a backtrackingpoint in the solver + virtual void pop() = 0; + //! pops multiple backtrack points + //! @param n number of backtrackingpoint to pop + virtual void pop(uint_fast64_t n) = 0; + //! removes all assertions + virtual void reset() = 0; + + + //! assert an expression in the solver + //! @param e the asserted expression, the return type has to be bool + //! @throws IllegalArgumentTypeException if the return type of the expression is not bool + virtual void assertExpression(storm::expressions::Expression const& e) = 0; + //! assert a set of expressions in the solver + //! @param es the asserted expressions + //! @see assert(storm::expressions::Expression &e) + virtual void assertExpression(std::set const& es) { + for (storm::expressions::Expression e : es) { + this->assertExpression(e); + } + } + //! assert a set of expressions in the solver + //! @param es the asserted expressions + //! @see assert(storm::expressions::Expression &e) + /* std::hash unavailable for expressions + virtual void assertExpression(std::unordered_set &es) { + for (storm::expressions::Expression e : es) { + this->assertExpression(e); + } + }*/ + //! assert a set of expressions in the solver + //! @param es the asserted expressions + //! @see assert(storm::expressions::Expression &e) + virtual void assertExpression(std::initializer_list const& es) { + for (storm::expressions::Expression e : es) { + this->assertExpression(e); + } + } + + /*! + * check satisfiability of the conjunction of the currently asserted expressions + * + * @returns CheckResult::SAT if the conjunction of the asserted expressions is satisfiable, + * CheckResult::UNSAT if it is unsatisfiable and CheckResult::UNKNOWN if the solver + * could not determine satisfiability + */ + virtual CheckResult check() = 0; + //! check satisfiability of the conjunction of the currently asserted expressions and the provided assumptions + //! @param es the asserted expressions + //! @throws IllegalArgumentTypeException if the return type of one of the expressions is not bool + //! @see check() + virtual CheckResult checkWithAssumptions(std::set const& assumptions) = 0; + //! check satisfiability of the conjunction of the currently asserted expressions and the provided assumptions + //! @param es the asserted expressions + //! @throws IllegalArgumentTypeException if the return type of one of the expressions is not bool + //! @see check() + /* std::hash unavailable for expressions + virtual CheckResult checkWithAssumptions(std::unordered_set &assumptions) = 0; + */ + //! check satisfiability of the conjunction of the currently asserted expressions and the provided assumptions + //! @param es the asserted expressions + //! @throws IllegalArgumentTypeException if the return type of one of the expressions is not bool + //! @see check() + virtual CheckResult checkWithAssumptions(std::initializer_list assumptions) = 0; + + /*! + * Gets a model for the assertion stack (and possibly assumtions) for the last call to ::check or ::checkWithAssumptions if that call + * returned CheckResult::SAT. Otherwise an exception is thrown. + * @remark Note that this function may throw an exception if it is not called immediately after a call to::check or ::checkWithAssumptions + * that returned CheckResult::SAT depending on the implementation. + * @throws InvalidStateException if no model is available + * @throws IllegalFunctionCallException if model generation is not configured for this solver + * @throws NotImplementedException if model generation is not implemented with this solver class + */ + virtual storm::expressions::SimpleValuation getModel() { + throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support model generation."); + } + + /*! + * Performs all AllSat over the important atoms. All valuations of the important atoms such that the currently asserted formulas are satisfiable + * are returned from the function. + * + * @warning If infinitely many valuations exist, such that the currently asserted formulas are satisfiable, this function will never return! + * + * @param important A set of expressions over which to perform all sat. + * + * @returns the set of all valuations of the important atoms, such that the currently asserted formulas are satisfiable + * + * @throws IllegalFunctionCallException if model generation is not configured for this solver + * @throws NotImplementedException if model generation is not implemented with this solver class + */ + virtual std::vector allSat(std::vector const& important) { + throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support model generation."); + } + + /*! + * Performs all AllSat over the important atoms. Once a valuation of the important atoms such that the currently asserted formulas are satisfiable + * is found the callback is called with that valuation. + * + * @param important A set of expressions over which to perform all sat. + * @param callback A function to call for each found valuation. + * + * @returns the number of valuations of the important atoms, such that the currently asserted formulas are satisfiable that where found + * + * @throws IllegalFunctionCallException if model generation is not configured for this solver + * @throws NotImplementedException if model generation is not implemented with this solver class + */ + virtual uint_fast64_t allSat(std::vector const& important, std::function callback) { + throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support model generation."); + } + + /*! + * Retrieves the unsat core of the last call to check() + * + * @returns a subset of the asserted formulas s.t. this subset is unsat + * + * @throws InvalidStateException if no unsat core is available, i.e. the asserted formulas are consistent + * @throws IllegalFunctionCallException if unsat core generation is not configured for this solver + * @throws NotImplementedException if unsat core generation is not implemented with this solver class + */ + virtual std::vector getUnsatCore() { + throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support unsat core generation."); + } + + /*! + * Retrieves a subset of the assumptions from the last call to checkWithAssumptions(), s.t. the result is still unsatisfiable + * + * @returns a subset of the assumptions s.t. this subset of the assumptions results in unsat + * + * @throws InvalidStateException if no unsat assumptions is available, i.e. the asserted formulas are consistent + * @throws IllegalFunctionCallException if unsat assumptions generation is not configured for this solver + * @throws NotImplementedException if unsat assumptions generation is not implemented with this solver class + */ + virtual std::vector getUnsatAssumptions() { + throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support unsat core generation."); + } + }; + } +} + +#endif // STORM_SOLVER_SMTSOLVER diff --git a/src/solver/Z3SmtSolver.cpp b/src/solver/Z3SmtSolver.cpp new file mode 100644 index 000000000..45720c293 --- /dev/null +++ b/src/solver/Z3SmtSolver.cpp @@ -0,0 +1,266 @@ +#include "src/solver/Z3SmtSolver.h" + + +namespace storm { + namespace solver { + Z3SmtSolver::Z3SmtSolver(Options options) +#ifdef STORM_HAVE_Z3 + : m_context() + , m_solver(m_context) + , m_adapter(m_context, std::map()) + , lastCheckAssumptions(false) + , lastResult(CheckResult::UNKNOWN) +#endif + { + //intentionally left empty + } + Z3SmtSolver::~Z3SmtSolver() {}; + + void Z3SmtSolver::push() + { +#ifdef STORM_HAVE_Z3 + this->m_solver.push(); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + void Z3SmtSolver::pop() + { +#ifdef STORM_HAVE_Z3 + this->m_solver.pop(); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + void Z3SmtSolver::pop(uint_fast64_t n) + { +#ifdef STORM_HAVE_Z3 + this->m_solver.pop((unsigned int)n); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + void Z3SmtSolver::reset() + { +#ifdef STORM_HAVE_Z3 + this->m_solver.reset(); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + void Z3SmtSolver::assertExpression(storm::expressions::Expression const& e) + { +#ifdef STORM_HAVE_Z3 + this->m_solver.add(m_adapter.translateExpression(e, true)); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + SmtSolver::CheckResult Z3SmtSolver::check() + { +#ifdef STORM_HAVE_Z3 + lastCheckAssumptions = false; + switch (this->m_solver.check()) { + case z3::sat: + this->lastResult = SmtSolver::CheckResult::SAT; + break; + case z3::unsat: + this->lastResult = SmtSolver::CheckResult::UNSAT; + break; + default: + this->lastResult = SmtSolver::CheckResult::UNKNOWN; + break; + } + return this->lastResult; +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + SmtSolver::CheckResult Z3SmtSolver::checkWithAssumptions(std::set const& assumptions) + { +#ifdef STORM_HAVE_Z3 + lastCheckAssumptions = true; + z3::expr_vector z3Assumptions(this->m_context); + + for (storm::expressions::Expression assumption : assumptions) { + z3Assumptions.push_back(this->m_adapter.translateExpression(assumption)); + } + + switch (this->m_solver.check(z3Assumptions)) { + case z3::sat: + this->lastResult = SmtSolver::CheckResult::SAT; + break; + case z3::unsat: + this->lastResult = SmtSolver::CheckResult::UNSAT; + break; + default: + this->lastResult = SmtSolver::CheckResult::UNKNOWN; + break; + } + return this->lastResult; +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + SmtSolver::CheckResult Z3SmtSolver::checkWithAssumptions(std::initializer_list assumptions) + { +#ifdef STORM_HAVE_Z3 + lastCheckAssumptions = true; + z3::expr_vector z3Assumptions(this->m_context); + + for (storm::expressions::Expression assumption : assumptions) { + z3Assumptions.push_back(this->m_adapter.translateExpression(assumption)); + } + + switch (this->m_solver.check(z3Assumptions)) { + case z3::sat: + this->lastResult = SmtSolver::CheckResult::SAT; + break; + case z3::unsat: + this->lastResult = SmtSolver::CheckResult::UNSAT; + break; + default: + this->lastResult = SmtSolver::CheckResult::UNKNOWN; + break; + } + return this->lastResult; +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + storm::expressions::SimpleValuation Z3SmtSolver::getModel() + { +#ifdef STORM_HAVE_Z3 + + LOG_THROW(this->lastResult == SmtSolver::CheckResult::SAT, storm::exceptions::InvalidStateException, "Requested Model but last check result was not SAT."); + + return this->z3ModelToStorm(this->m_solver.get_model()); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + +#ifdef STORM_HAVE_Z3 + storm::expressions::SimpleValuation Z3SmtSolver::z3ModelToStorm(z3::model m) { + storm::expressions::SimpleValuation stormModel; + + for (unsigned i = 0; i < m.num_consts(); ++i) { + z3::func_decl var_i = m.get_const_decl(i); + storm::expressions::Expression var_i_interp = this->m_adapter.translateExpression(m.get_const_interp(var_i)); + + switch (var_i_interp.getReturnType()) { + case storm::expressions::ExpressionReturnType::Bool: + stormModel.addBooleanIdentifier(var_i.name().str(), var_i_interp.evaluateAsBool()); + break; + case storm::expressions::ExpressionReturnType::Int: + stormModel.addIntegerIdentifier(var_i.name().str(), var_i_interp.evaluateAsInt()); + break; + case storm::expressions::ExpressionReturnType::Double: + stormModel.addDoubleIdentifier(var_i.name().str(), var_i_interp.evaluateAsDouble()); + break; + default: + LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Variable interpretation in model is not of type bool, int or double.") + break; + } + + } + + return stormModel; + } +#endif + + std::vector Z3SmtSolver::allSat(std::vector const& important) + { +#ifdef STORM_HAVE_Z3 + + std::vector valuations; + + this->allSat(important, [&valuations](storm::expressions::SimpleValuation& valuation) -> bool {valuations.push_back(valuation); return true; }); + + return valuations; + +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + uint_fast64_t Z3SmtSolver::allSat(std::vector const& important, std::function callback) + { +#ifdef STORM_HAVE_Z3 + for (storm::expressions::Expression e : important) { + if (!e.isVariable()) { + throw storm::exceptions::InvalidArgumentException() << "The important expressions for AllSat must be atoms, i.e. variable expressions."; + } + } + + uint_fast64_t numModels = 0; + bool proceed = true; + + this->push(); + + while (proceed && this->check() == CheckResult::SAT) { + ++numModels; + z3::model m = this->m_solver.get_model(); + + z3::expr modelExpr = this->m_context.bool_val(true); + storm::expressions::SimpleValuation valuation; + + for (storm::expressions::Expression importantAtom : important) { + z3::expr z3ImportantAtom = this->m_adapter.translateExpression(importantAtom); + z3::expr z3ImportantAtomValuation = m.eval(z3ImportantAtom, true); + modelExpr = modelExpr && (z3ImportantAtom == z3ImportantAtomValuation); + if (importantAtom.getReturnType() == storm::expressions::ExpressionReturnType::Bool) { + valuation.addBooleanIdentifier(importantAtom.getIdentifier(), this->m_adapter.translateExpression(z3ImportantAtomValuation).evaluateAsBool()); + } else if (importantAtom.getReturnType() == storm::expressions::ExpressionReturnType::Int) { + valuation.addIntegerIdentifier(importantAtom.getIdentifier(), this->m_adapter.translateExpression(z3ImportantAtomValuation).evaluateAsInt()); + } else if (importantAtom.getReturnType() == storm::expressions::ExpressionReturnType::Double) { + valuation.addDoubleIdentifier(importantAtom.getIdentifier(), this->m_adapter.translateExpression(z3ImportantAtomValuation).evaluateAsDouble()); + } else { + throw storm::exceptions::InvalidTypeException() << "Important atom has invalid type"; + } + } + + proceed = callback(valuation); + + this->m_solver.add(!modelExpr); + } + + this->pop(); + + return numModels; +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + std::vector Z3SmtSolver::getUnsatAssumptions() { +#ifdef STORM_HAVE_Z3 + if (lastResult != SmtSolver::CheckResult::UNSAT) { + throw storm::exceptions::InvalidStateException() << "Unsat Assumptions was called but last state is not unsat."; + } + if (!lastCheckAssumptions) { + throw storm::exceptions::InvalidStateException() << "Unsat Assumptions was called but last check had no assumptions."; + } + + z3::expr_vector z3UnsatAssumptions = this->m_solver.unsat_core(); + + std::vector unsatAssumptions; + + for (unsigned int i = 0; i < z3UnsatAssumptions.size(); ++i) { + unsatAssumptions.push_back(this->m_adapter.translateExpression(z3UnsatAssumptions[i])); + } + + return unsatAssumptions; +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + } +} \ No newline at end of file diff --git a/src/solver/Z3SmtSolver.h b/src/solver/Z3SmtSolver.h new file mode 100644 index 000000000..0f2059443 --- /dev/null +++ b/src/solver/Z3SmtSolver.h @@ -0,0 +1,61 @@ +#ifndef STORM_SOLVER_Z3SMTSOLVER +#define STORM_SOLVER_Z3SMTSOLVER + +#include "storm-config.h" +#include "src/solver/SmtSolver.h" +#include "src/adapters/Z3ExpressionAdapter.h" + +#ifdef STORM_HAVE_Z3 +#include "z3++.h" +#include "z3.h" +#endif + +namespace storm { + namespace solver { + class Z3SmtSolver : public SmtSolver { + public: + Z3SmtSolver(Options options = Options::ModelGeneration); + virtual ~Z3SmtSolver(); + + virtual void push() override; + + virtual void pop() override; + + virtual void pop(uint_fast64_t n) override; + + virtual void reset() override; + + virtual void assertExpression(storm::expressions::Expression const& e) override; + + virtual CheckResult check() override; + + virtual CheckResult checkWithAssumptions(std::set const& assumptions) override; + + virtual CheckResult checkWithAssumptions(std::initializer_list assumptions) override; + + virtual storm::expressions::SimpleValuation getModel() override; + + virtual std::vector allSat(std::vector const& important) override; + + virtual uint_fast64_t allSat(std::vector const& important, std::function callback) override; + + virtual std::vector getUnsatAssumptions() override; + + protected: +#ifdef STORM_HAVE_Z3 + virtual storm::expressions::SimpleValuation z3ModelToStorm(z3::model m); +#endif + private: + +#ifdef STORM_HAVE_Z3 + z3::context m_context; + z3::solver m_solver; + storm::adapters::Z3ExpressionAdapter m_adapter; + + bool lastCheckAssumptions; + CheckResult lastResult; +#endif + }; + } +} +#endif // STORM_SOLVER_Z3SMTSOLVER \ No newline at end of file diff --git a/src/storage/SparseMatrix.cpp b/src/storage/SparseMatrix.cpp index c300e6858..4ffcbe991 100644 --- a/src/storage/SparseMatrix.cpp +++ b/src/storage/SparseMatrix.cpp @@ -787,8 +787,8 @@ namespace storm { const_iterator ite; std::vector::const_iterator rowIterator = this->rowIndications.begin() + startRow; std::vector::const_iterator rowIteratorEnd = this->rowIndications.begin() + endRow; - std::vector::iterator resultIterator = result.begin() + startRow; - std::vector::iterator resultIteratorEnd = result.begin() + endRow; + typename std::vector::iterator resultIterator = result.begin() + startRow; + typename std::vector::iterator resultIteratorEnd = result.begin() + endRow; for (; resultIterator != resultIteratorEnd; ++rowIterator, ++resultIterator) { *resultIterator = storm::utility::constantZero(); diff --git a/src/storage/expressions/BaseExpression.h b/src/storage/expressions/BaseExpression.h index 9524fcdbb..8afe11b91 100644 --- a/src/storage/expressions/BaseExpression.h +++ b/src/storage/expressions/BaseExpression.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "src/storage/expressions/ExpressionReturnType.h" @@ -147,7 +148,14 @@ namespace storm { * @return The set of all variables that appear in the expression. */ virtual std::set getVariables() const = 0; - + + /*! + * Retrieves the mapping of all variables that appear in the expression to their return type. + * + * @return The mapping of all variables that appear in the expression to their return type. + */ + virtual std::map getVariablesAndTypes() const = 0; + /*! * Simplifies the expression according to some simple rules. * diff --git a/src/storage/expressions/BinaryExpression.cpp b/src/storage/expressions/BinaryExpression.cpp index 4d3c4be81..bdb7fb59f 100644 --- a/src/storage/expressions/BinaryExpression.cpp +++ b/src/storage/expressions/BinaryExpression.cpp @@ -15,14 +15,21 @@ namespace storm { bool BinaryExpression::containsVariables() const { return this->getFirstOperand()->containsVariables() || this->getSecondOperand()->containsVariables(); - } - - std::set BinaryExpression::getVariables() const { - std::set firstVariableSet = this->getFirstOperand()->getVariables(); - std::set secondVariableSet = this->getSecondOperand()->getVariables(); - firstVariableSet.insert(secondVariableSet.begin(), secondVariableSet.end()); - return firstVariableSet; - } + } + + std::set BinaryExpression::getVariables() const { + std::set firstVariableSet = this->getFirstOperand()->getVariables(); + std::set secondVariableSet = this->getSecondOperand()->getVariables(); + firstVariableSet.insert(secondVariableSet.begin(), secondVariableSet.end()); + return firstVariableSet; + } + + std::map BinaryExpression::getVariablesAndTypes() const { + std::map firstVariableSet = this->getFirstOperand()->getVariablesAndTypes(); + std::map secondVariableSet = this->getSecondOperand()->getVariablesAndTypes(); + firstVariableSet.insert(secondVariableSet.begin(), secondVariableSet.end()); + return firstVariableSet; + } std::shared_ptr const& BinaryExpression::getFirstOperand() const { return this->firstOperand; diff --git a/src/storage/expressions/BinaryExpression.h b/src/storage/expressions/BinaryExpression.h index 75c9283cb..3f9be4d06 100644 --- a/src/storage/expressions/BinaryExpression.h +++ b/src/storage/expressions/BinaryExpression.h @@ -33,8 +33,9 @@ namespace storm { virtual bool isFunctionApplication() const override; virtual bool containsVariables() const override; virtual uint_fast64_t getArity() const override; - virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; - virtual std::set getVariables() const override; + virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; /*! * Retrieves the first operand of the expression. diff --git a/src/storage/expressions/BooleanLiteralExpression.cpp b/src/storage/expressions/BooleanLiteralExpression.cpp index 5112ec8a9..d510c6f45 100644 --- a/src/storage/expressions/BooleanLiteralExpression.cpp +++ b/src/storage/expressions/BooleanLiteralExpression.cpp @@ -24,7 +24,11 @@ namespace storm { std::set BooleanLiteralExpression::getVariables() const { return std::set(); - } + } + + std::map BooleanLiteralExpression::getVariablesAndTypes() const { + return std::map(); + } std::shared_ptr BooleanLiteralExpression::simplify() const { return this->shared_from_this(); diff --git a/src/storage/expressions/BooleanLiteralExpression.h b/src/storage/expressions/BooleanLiteralExpression.h index fc299e60a..57c19677c 100644 --- a/src/storage/expressions/BooleanLiteralExpression.h +++ b/src/storage/expressions/BooleanLiteralExpression.h @@ -29,7 +29,8 @@ namespace storm { virtual bool isLiteral() const override; virtual bool isTrue() const override; virtual bool isFalse() const override; - virtual std::set getVariables() const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; diff --git a/src/storage/expressions/DoubleLiteralExpression.cpp b/src/storage/expressions/DoubleLiteralExpression.cpp index b780a8862..00471b533 100644 --- a/src/storage/expressions/DoubleLiteralExpression.cpp +++ b/src/storage/expressions/DoubleLiteralExpression.cpp @@ -16,7 +16,11 @@ namespace storm { std::set DoubleLiteralExpression::getVariables() const { return std::set(); - } + } + + std::map DoubleLiteralExpression::getVariablesAndTypes() const { + return std::map(); + } std::shared_ptr DoubleLiteralExpression::simplify() const { return this->shared_from_this(); diff --git a/src/storage/expressions/DoubleLiteralExpression.h b/src/storage/expressions/DoubleLiteralExpression.h index ad22ee3be..515326b8b 100644 --- a/src/storage/expressions/DoubleLiteralExpression.h +++ b/src/storage/expressions/DoubleLiteralExpression.h @@ -27,7 +27,8 @@ namespace storm { // Override base class methods. virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual bool isLiteral() const override; - virtual std::set getVariables() const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index 2c247f796..c61e93b9d 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -95,6 +95,21 @@ namespace storm { bool Expression::isFalse() const { return this->getBaseExpression().isFalse(); } + + std::set Expression::getVariables() const { + return this->getBaseExpression().getVariables(); + } + + std::map Expression::getVariablesAndTypes(bool validate) const { + if (validate) { + std::map result = this->getBaseExpression().getVariablesAndTypes(); + this->check(result); + return result; + } + else { + return this->getBaseExpression().getVariablesAndTypes(); + } + } bool Expression::isRelationalExpression() const { if (!this->isFunctionApplication()) { @@ -110,10 +125,6 @@ namespace storm { return LinearityCheckVisitor().check(*this); } - std::set Expression::getVariables() const { - return this->getBaseExpression().getVariables(); - } - BaseExpression const& Expression::getBaseExpression() const { return *this->expressionPtr; } diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index 08c758f27..504e3d96e 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -253,6 +253,19 @@ namespace storm { */ std::set getVariables() const; + /*! + * Retrieves the mapping of all variables that appear in the expression to their return type. + * + * @param validate If this parameter is true, check() is called with the returnvalue before + * it is returned. + * + * @throws storm::exceptions::InvalidTypeException If a variables with the same name but different + * types occur somewhere withing the expression. + * + * @return The mapping of all variables that appear in the expression to their return type. + */ + std::map getVariablesAndTypes(bool validate = true) const; + /*! * Retrieves the base expression underlying this expression object. Note that prior to calling this, the * expression object must be properly initialized. diff --git a/src/storage/expressions/IfThenElseExpression.cpp b/src/storage/expressions/IfThenElseExpression.cpp index bbd0b28d2..50c71af75 100644 --- a/src/storage/expressions/IfThenElseExpression.cpp +++ b/src/storage/expressions/IfThenElseExpression.cpp @@ -61,16 +61,25 @@ namespace storm { } else { return this->elseExpression->evaluateAsDouble(valuation); } - } - - std::set IfThenElseExpression::getVariables() const { - std::set result = this->condition->getVariables(); - std::set tmp = this->thenExpression->getVariables(); - result.insert(tmp.begin(), tmp.end()); - tmp = this->elseExpression->getVariables(); - result.insert(tmp.begin(), tmp.end()); - return result; - } + } + + std::set IfThenElseExpression::getVariables() const { + std::set result = this->condition->getVariables(); + std::set tmp = this->thenExpression->getVariables(); + result.insert(tmp.begin(), tmp.end()); + tmp = this->elseExpression->getVariables(); + result.insert(tmp.begin(), tmp.end()); + return result; + } + + std::map IfThenElseExpression::getVariablesAndTypes() const { + std::map result = this->condition->getVariablesAndTypes(); + std::map tmp = this->thenExpression->getVariablesAndTypes(); + result.insert(tmp.begin(), tmp.end()); + tmp = this->elseExpression->getVariablesAndTypes(); + result.insert(tmp.begin(), tmp.end()); + return result; + } std::shared_ptr IfThenElseExpression::simplify() const { std::shared_ptr conditionSimplified; diff --git a/src/storage/expressions/IfThenElseExpression.h b/src/storage/expressions/IfThenElseExpression.h index 425633850..63619d852 100644 --- a/src/storage/expressions/IfThenElseExpression.h +++ b/src/storage/expressions/IfThenElseExpression.h @@ -35,7 +35,8 @@ namespace storm { virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override; virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override; virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; - virtual std::set getVariables() const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; diff --git a/src/storage/expressions/IntegerLiteralExpression.cpp b/src/storage/expressions/IntegerLiteralExpression.cpp index 47341a771..b05ab5a67 100644 --- a/src/storage/expressions/IntegerLiteralExpression.cpp +++ b/src/storage/expressions/IntegerLiteralExpression.cpp @@ -16,11 +16,15 @@ namespace storm { bool IntegerLiteralExpression::isLiteral() const { return true; - } - - std::set IntegerLiteralExpression::getVariables() const { - return std::set(); - } + } + + std::set IntegerLiteralExpression::getVariables() const { + return std::set(); + } + + std::map IntegerLiteralExpression::getVariablesAndTypes() const { + return std::map(); + } std::shared_ptr IntegerLiteralExpression::simplify() const { return this->shared_from_this(); diff --git a/src/storage/expressions/IntegerLiteralExpression.h b/src/storage/expressions/IntegerLiteralExpression.h index 4a9e8882f..5d6c731a5 100644 --- a/src/storage/expressions/IntegerLiteralExpression.h +++ b/src/storage/expressions/IntegerLiteralExpression.h @@ -28,7 +28,8 @@ namespace storm { virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override; virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual bool isLiteral() const override; - virtual std::set getVariables() const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; diff --git a/src/storage/expressions/OperatorType.cpp b/src/storage/expressions/OperatorType.cpp new file mode 100644 index 000000000..a4117d97c --- /dev/null +++ b/src/storage/expressions/OperatorType.cpp @@ -0,0 +1,33 @@ +#include "src/storage/expressions/OperatorType.h" + +namespace storm { + namespace expressions { + std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType) { + switch (operatorType) { + case OperatorType::And: stream << "&"; break; + case OperatorType::Or: stream << "|"; break; + case OperatorType::Xor: stream << "!="; break; + case OperatorType::Implies: stream << "=>"; break; + case OperatorType::Iff: stream << "<=>"; break; + case OperatorType::Plus: stream << "+"; break; + case OperatorType::Minus: stream << "-"; break; + case OperatorType::Times: stream << "*"; break; + case OperatorType::Divide: stream << "/"; break; + case OperatorType::Min: stream << "min"; break; + case OperatorType::Max: stream << "max"; break; + case OperatorType::Power: stream << "^"; break; + case OperatorType::Equal: stream << "="; break; + case OperatorType::NotEqual: stream << "!="; break; + case OperatorType::Less: stream << "<"; break; + case OperatorType::LessOrEqual: stream << "<="; break; + case OperatorType::Greater: stream << ">"; break; + case OperatorType::GreaterOrEqual: stream << ">="; break; + case OperatorType::Not: stream << "!"; break; + case OperatorType::Floor: stream << "floor"; break; + case OperatorType::Ceil: stream << "ceil"; break; + case OperatorType::Ite: stream << "ite"; break; + } + return stream; + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/OperatorType.h b/src/storage/expressions/OperatorType.h index 8968cf105..5908bf789 100644 --- a/src/storage/expressions/OperatorType.h +++ b/src/storage/expressions/OperatorType.h @@ -1,6 +1,8 @@ #ifndef STORM_STORAGE_EXPRESSIONS_OPERATORTYPE_H_ #define STORM_STORAGE_EXPRESSIONS_OPERATORTYPE_H_ +#include + namespace storm { namespace expressions { // An enum representing all possible operator types. @@ -28,6 +30,8 @@ namespace storm { Ceil, Ite }; + + std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType); } } diff --git a/src/storage/expressions/SimpleValuation.cpp b/src/storage/expressions/SimpleValuation.cpp index 923bbe84b..f50b1ef49 100644 --- a/src/storage/expressions/SimpleValuation.cpp +++ b/src/storage/expressions/SimpleValuation.cpp @@ -15,17 +15,17 @@ namespace storm { void SimpleValuation::addBooleanIdentifier(std::string const& name, bool initialValue) { LOG_THROW(this->identifierToValueMap.find(name) == this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Identifier '" << name << "' already registered."); - this->identifierToValueMap.emplace(name, initialValue); + this->identifierToValueMap.emplace(name, initialValue); } void SimpleValuation::addIntegerIdentifier(std::string const& name, int_fast64_t initialValue) { LOG_THROW(this->identifierToValueMap.find(name) == this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Identifier '" << name << "' already registered."); - this->identifierToValueMap.emplace(name, initialValue); + this->identifierToValueMap.emplace(name, initialValue); } void SimpleValuation::addDoubleIdentifier(std::string const& name, double initialValue) { LOG_THROW(this->identifierToValueMap.find(name) == this->identifierToValueMap.end(), storm::exceptions::InvalidArgumentException, "Identifier '" << name << "' already registered."); - this->identifierToValueMap.emplace(name, initialValue); + this->identifierToValueMap.emplace(name, initialValue); } void SimpleValuation::setBooleanValue(std::string const& name, bool value) { diff --git a/src/storage/expressions/TypeCheckVisitor.cpp b/src/storage/expressions/TypeCheckVisitor.cpp index 5ab80e141..521453042 100644 --- a/src/storage/expressions/TypeCheckVisitor.cpp +++ b/src/storage/expressions/TypeCheckVisitor.cpp @@ -43,8 +43,8 @@ namespace storm { template void TypeCheckVisitor::visit(VariableExpression const* expression) { - auto identifierTypePair = this->identifierToTypeMap.find(expression->getVariableName()); - LOG_THROW(identifierTypePair != this->identifierToTypeMap.end(), storm::exceptions::InvalidArgumentException, "No type available for identifier '" << expression->getVariableName() << "'."); + auto identifierTypePair = this->identifierToTypeMap.find(expression->getVariableName()); + LOG_THROW(identifierTypePair != this->identifierToTypeMap.end(), storm::exceptions::InvalidArgumentException, "No type available for identifier '" << expression->getVariableName() << "'."); LOG_THROW(identifierTypePair->second == expression->getReturnType(), storm::exceptions::InvalidTypeException, "Type mismatch for variable '" << expression->getVariableName() << "': expected '" << identifierTypePair->first << "', but found '" << expression->getReturnType() << "'."); } diff --git a/src/storage/expressions/UnaryExpression.cpp b/src/storage/expressions/UnaryExpression.cpp index 5d8262766..169466003 100644 --- a/src/storage/expressions/UnaryExpression.cpp +++ b/src/storage/expressions/UnaryExpression.cpp @@ -16,10 +16,14 @@ namespace storm { bool UnaryExpression::containsVariables() const { return this->getOperand()->containsVariables(); } - - std::set UnaryExpression::getVariables() const { - return this->getOperand()->getVariables(); - } + + std::set UnaryExpression::getVariables() const { + return this->getOperand()->getVariables(); + } + + std::map UnaryExpression::getVariablesAndTypes() const { + return this->getOperand()->getVariablesAndTypes(); + } std::shared_ptr const& UnaryExpression::getOperand() const { return this->operand; diff --git a/src/storage/expressions/UnaryExpression.h b/src/storage/expressions/UnaryExpression.h index a387ad8d5..5473d3122 100644 --- a/src/storage/expressions/UnaryExpression.h +++ b/src/storage/expressions/UnaryExpression.h @@ -30,7 +30,8 @@ namespace storm { virtual bool containsVariables() const override; virtual uint_fast64_t getArity() const override; virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; - virtual std::set getVariables() const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; /*! * Retrieves the operand of the unary expression. diff --git a/src/storage/expressions/VariableExpression.cpp b/src/storage/expressions/VariableExpression.cpp index 2062c58ee..a9f714569 100644 --- a/src/storage/expressions/VariableExpression.cpp +++ b/src/storage/expressions/VariableExpression.cpp @@ -48,11 +48,19 @@ namespace storm { bool VariableExpression::containsVariables() const { return true; } + + bool VariableExpression::isVariable() const { + return true; + } std::set VariableExpression::getVariables() const { return {this->getVariableName()}; } + std::map VariableExpression::getVariablesAndTypes() const { + return{ std::make_pair(this->getVariableName(), this->getReturnType()) }; + } + std::shared_ptr VariableExpression::simplify() const { return this->shared_from_this(); } diff --git a/src/storage/expressions/VariableExpression.h b/src/storage/expressions/VariableExpression.h index 6eff8c794..dda150495 100644 --- a/src/storage/expressions/VariableExpression.h +++ b/src/storage/expressions/VariableExpression.h @@ -31,7 +31,9 @@ namespace storm { virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual std::string const& getIdentifier() const override; virtual bool containsVariables() const override; - virtual std::set getVariables() const override; + virtual bool isVariable() const override; + virtual std::set getVariables() const override; + virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; diff --git a/src/storm.cpp b/src/storm.cpp index a41552d46..8e3ade79e 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -38,7 +38,7 @@ #include "src/solver/GmmxxNondeterministicLinearEquationSolver.h" #include "src/solver/GurobiLpSolver.h" #include "src/counterexamples/MILPMinimalLabelSetGenerator.h" -// #include "src/counterexamples/SMTMinimalCommandSetGenerator.h" +#include "src/counterexamples/SMTMinimalCommandSetGenerator.h" #include "src/counterexamples/PathBasedSubsystemGenerator.h" #include "src/parser/AutoParser.h" #include "src/parser/MarkovAutomatonParser.h" @@ -619,7 +619,7 @@ int main(const int argc, const char* argv[]) { if (useMILP) { storm::counterexamples::MILPMinimalLabelSetGenerator::computeCounterexample(program, *mdp, formulaPtr); } else { - // storm::counterexamples::SMTMinimalCommandSetGenerator::computeCounterexample(program, constants, *mdp, formulaPtr); + storm::counterexamples::SMTMinimalCommandSetGenerator::computeCounterexample(program, constants, *mdp, formulaPtr); } // Once we are done with the formula, delete it. @@ -628,7 +628,7 @@ int main(const int argc, const char* argv[]) { // MinCMD Time Measurement, End std::chrono::high_resolution_clock::time_point minCmdEnd = std::chrono::high_resolution_clock::now(); - std::cout << "Minimal command Counterexample generation took " << std::chrono::duration_cast(minCmdStart - minCmdEnd).count() << " milliseconds." << std::endl; + std::cout << "Minimal command Counterexample generation took " << std::chrono::duration_cast(minCmdEnd - minCmdStart).count() << " milliseconds." << std::endl; } else if (s->isSet("prctl")) { // Depending on the model type, the appropriate model checking procedure is chosen. storm::modelchecker::prctl::AbstractModelChecker* modelchecker = nullptr; diff --git a/src/utility/PrismUtility.h b/src/utility/PrismUtility.h index 9ed15411a..de8bcb2c7 100644 --- a/src/utility/PrismUtility.h +++ b/src/utility/PrismUtility.h @@ -188,6 +188,62 @@ namespace storm { auto& labeledEntry = choice.getOrAddEntry(state); labeledEntry.addValue(probability, labels); } + + static std::map parseConstantDefinitionString(storm::prism::Program const& program, std::string const& constantDefinitionString) { + std::map constantDefinitions; + std::set definedConstants; + + if (!constantDefinitionString.empty()) { + // Parse the string that defines the undefined constants of the model and make sure that it contains exactly + // one value for each undefined constant of the model. + std::vector definitions; + boost::split(definitions, constantDefinitionString, boost::is_any_of(",")); + for (auto& definition : definitions) { + boost::trim(definition); + + // Check whether the token could be a legal constant definition. + uint_fast64_t positionOfAssignmentOperator = definition.find('='); + if (positionOfAssignmentOperator == std::string::npos) { + throw storm::exceptions::InvalidArgumentException() << "Illegal constant definition string: syntax error."; + } + + // Now extract the variable name and the value from the string. + std::string constantName = definition.substr(0, positionOfAssignmentOperator); + boost::trim(constantName); + std::string value = definition.substr(positionOfAssignmentOperator + 1); + boost::trim(value); + + // Check whether the constant is a legal undefined constant of the program and if so, of what type it is. + if (program.hasConstant(constantName)) { + // Get the actual constant and check whether it's in fact undefined. + auto const& constant = program.getConstant(constantName); + LOG_THROW(!constant.isDefined(), storm::exceptions::InvalidArgumentException, "Illegally trying to define already defined constant '" << constantName <<"'."); + LOG_THROW(definedConstants.find(constantName) == definedConstants.end(), storm::exceptions::InvalidArgumentException, "Illegally trying to define constant '" << constantName <<"' twice."); + definedConstants.insert(constantName); + + if (constant.getType() == storm::expressions::ExpressionReturnType::Bool) { + if (value == "true") { + constantDefinitions[constantName] = storm::expressions::Expression::createTrue(); + } else if (value == "false") { + constantDefinitions[constantName] = storm::expressions::Expression::createFalse(); + } else { + throw storm::exceptions::InvalidArgumentException() << "Illegal value for boolean constant: " << value << "."; + } + } else if (constant.getType() == storm::expressions::ExpressionReturnType::Int) { + int_fast64_t integerValue = std::stoi(value); + constantDefinitions[constantName] = storm::expressions::Expression::createIntegerLiteral(integerValue); + } else if (constant.getType() == storm::expressions::ExpressionReturnType::Double) { + double doubleValue = std::stod(value); + constantDefinitions[constantName] = storm::expressions::Expression::createDoubleLiteral(doubleValue); + } + } else { + throw storm::exceptions::InvalidArgumentException() << "Illegal constant definition string: unknown undefined constant " << constantName << "."; + } + } + } + + return constantDefinitions; + } } // namespace prism } // namespace utility } // namespace storm diff --git a/src/utility/vector.h b/src/utility/vector.h index 6dca02082..79fd28046 100644 --- a/src/utility/vector.h +++ b/src/utility/vector.h @@ -237,7 +237,7 @@ namespace storm { for (; targetIt != targetIte; ++targetIt, ++rowGroupingIt) { *targetIt = *sourceIt; ++sourceIt; - localChoice = 0; + localChoice = 1; if (choices != nullptr) { *choiceIt = 0; } @@ -271,7 +271,7 @@ namespace storm { for (; targetIt != targetIte; ++targetIt, ++rowGroupingIt) { *targetIt = *sourceIt; ++sourceIt; - localChoice = 0; + localChoice = 1; if (choices != nullptr) { *choiceIt = 0; } diff --git a/test/functional/adapter/Z3ExpressionAdapterTest.cpp b/test/functional/adapter/Z3ExpressionAdapterTest.cpp new file mode 100644 index 000000000..b92794f87 --- /dev/null +++ b/test/functional/adapter/Z3ExpressionAdapterTest.cpp @@ -0,0 +1,177 @@ +#include "gtest/gtest.h" +#include "storm-config.h" + +#ifdef STORM_HAVE_Z3 +#include "z3++.h" +#include "src/adapters/Z3ExpressionAdapter.h" +#include "src/settings/Settings.h" + +TEST(Z3ExpressionAdapter, StormToZ3Basic) { + z3::context ctx; + z3::solver s(ctx); + z3::expr conjecture = ctx.bool_val(false); + + storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map()); + + storm::expressions::Expression exprTrue = storm::expressions::Expression::createTrue(); + z3::expr z3True = ctx.bool_val(true); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprTrue), z3True))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); + + storm::expressions::Expression exprFalse = storm::expressions::Expression::createFalse(); + z3::expr z3False = ctx.bool_val(false); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprFalse), z3False))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); + + storm::expressions::Expression exprConjunction = (storm::expressions::Expression::createBooleanVariable("x") && storm::expressions::Expression::createBooleanVariable("y")); + z3::expr z3Conjunction = (ctx.bool_const("x") && ctx.bool_const("y")); + ASSERT_THROW( adapter.translateExpression(exprConjunction, false), std::out_of_range ); //variables not yet created in adapter + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprConjunction, true), z3Conjunction))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); + + storm::expressions::Expression exprNor = !(storm::expressions::Expression::createBooleanVariable("x") || storm::expressions::Expression::createBooleanVariable("y")); + z3::expr z3Nor = !(ctx.bool_const("x") || ctx.bool_const("y")); + ASSERT_NO_THROW(adapter.translateExpression(exprNor, false)); //variables already created in adapter + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprNor, true), z3Nor))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); +} + +TEST(Z3ExpressionAdapter, StormToZ3Integer) { + z3::context ctx; + z3::solver s(ctx); + z3::expr conjecture = ctx.bool_val(false); + + storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map()); + + storm::expressions::Expression exprAdd = (storm::expressions::Expression::createIntegerVariable("x") + storm::expressions::Expression::createIntegerVariable("y") < -storm::expressions::Expression::createIntegerVariable("y")); + z3::expr z3Add = (ctx.int_const("x") + ctx.int_const("y") < -ctx.int_const("y")); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprAdd, true), z3Add))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); + + storm::expressions::Expression exprMult = !(storm::expressions::Expression::createIntegerVariable("x") * storm::expressions::Expression::createIntegerVariable("y") == storm::expressions::Expression::createIntegerVariable("y")); + z3::expr z3Mult = !(ctx.int_const("x") * ctx.int_const("y") == ctx.int_const("y")); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprMult, true), z3Mult))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); +} + +TEST(Z3ExpressionAdapter, StormToZ3Real) { + z3::context ctx; + z3::solver s(ctx); + z3::expr conjecture = ctx.bool_val(false); + + storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map()); + + storm::expressions::Expression exprAdd = (storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createDoubleVariable("y") < -storm::expressions::Expression::createDoubleVariable("y")); + z3::expr z3Add = (ctx.real_const("x") + ctx.real_const("y") < -ctx.real_const("y")); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprAdd, true), z3Add))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); + + storm::expressions::Expression exprMult = !(storm::expressions::Expression::createDoubleVariable("x") * storm::expressions::Expression::createDoubleVariable("y") == storm::expressions::Expression::createDoubleVariable("y")); + z3::expr z3Mult = !(ctx.real_const("x") * ctx.real_const("y") == ctx.real_const("y")); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprMult, true), z3Mult))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); + s.reset(); +} + +TEST(Z3ExpressionAdapter, StormToZ3TypeErrors) { + z3::context ctx; + z3::solver s(ctx); + z3::expr conjecture = ctx.bool_val(false); + + storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map()); + + storm::expressions::Expression exprFail1 = (storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createIntegerVariable("y") < -storm::expressions::Expression::createDoubleVariable("y")); + ASSERT_THROW(conjecture = adapter.translateExpression(exprFail1, true), storm::exceptions::InvalidTypeException); +} + +TEST(Z3ExpressionAdapter, StormToZ3FloorCeil) { + z3::context ctx; + z3::solver s(ctx); + z3::expr conjecture = ctx.bool_val(false); + + storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map()); + + storm::expressions::Expression exprFloor = ((storm::expressions::Expression::createDoubleVariable("d").floor()) == storm::expressions::Expression::createIntegerVariable("i") && storm::expressions::Expression::createDoubleVariable("d") > storm::expressions::Expression::createDoubleLiteral(4.1) && storm::expressions::Expression::createDoubleVariable("d") < storm::expressions::Expression::createDoubleLiteral(4.991)); + z3::expr z3Floor = ctx.int_val(4) == ctx.int_const("i"); + + //try { adapter.translateExpression(exprFloor, true); } + //catch (std::exception &e) { std::cout << e.what() << std::endl; } + + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprFloor, true), z3Floor))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::sat); //it is NOT logical equivalent + s.reset(); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_implies(ctx, adapter.translateExpression(exprFloor, true), z3Floor))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); //it is NOT logical equivalent + s.reset(); + + + storm::expressions::Expression exprCeil = ((storm::expressions::Expression::createDoubleVariable("d").ceil()) == storm::expressions::Expression::createIntegerVariable("i") && storm::expressions::Expression::createDoubleVariable("d") > storm::expressions::Expression::createDoubleLiteral(4.1) && storm::expressions::Expression::createDoubleVariable("d") < storm::expressions::Expression::createDoubleLiteral(4.991)); + z3::expr z3Ceil = ctx.int_val(5) == ctx.int_const("i"); + + //try { adapter.translateExpression(exprFloor, true); } + //catch (std::exception &e) { std::cout << e.what() << std::endl; } + + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_iff(ctx, adapter.translateExpression(exprCeil, true), z3Ceil))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::sat); //it is NOT logical equivalent + s.reset(); + ASSERT_NO_THROW(conjecture = !z3::expr(ctx, Z3_mk_implies(ctx, adapter.translateExpression(exprCeil, true), z3Ceil))); + s.add(conjecture); + ASSERT_TRUE(s.check() == z3::unsat); //it is NOT logical equivalent + s.reset(); +} + +TEST(Z3ExpressionAdapter, Z3ToStormBasic) { + z3::context ctx; + + unsigned args = 2; + + storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map()); + + z3::expr z3True = ctx.bool_val(true); + storm::expressions::Expression exprTrue; + exprTrue = adapter.translateExpression(z3True); + ASSERT_TRUE(exprTrue.isTrue()); + + z3::expr z3False = ctx.bool_val(false); + storm::expressions::Expression exprFalse; + exprFalse = adapter.translateExpression(z3False); + ASSERT_TRUE(exprFalse.isFalse()); + + z3::expr z3Conjunction = (ctx.bool_const("x") && ctx.bool_const("y")); + storm::expressions::Expression exprConjunction; + (exprConjunction = adapter.translateExpression(z3Conjunction)); + ASSERT_EQ(storm::expressions::OperatorType::And, exprConjunction.getOperator()); + ASSERT_TRUE(exprConjunction.getOperand(0).isVariable()); + ASSERT_EQ("x", exprConjunction.getOperand(0).getIdentifier()); + ASSERT_TRUE(exprConjunction.getOperand(1).isVariable()); + ASSERT_EQ("y", exprConjunction.getOperand(1).getIdentifier()); + + z3::expr z3Nor = !(ctx.bool_const("x") || ctx.bool_const("y")); + storm::expressions::Expression exprNor; + (exprNor = adapter.translateExpression(z3Nor)); + ASSERT_EQ(storm::expressions::OperatorType::Not, exprNor.getOperator()); + ASSERT_EQ(storm::expressions::OperatorType::Or, exprNor.getOperand(0).getOperator()); + ASSERT_TRUE(exprNor.getOperand(0).getOperand(0).isVariable()); + ASSERT_EQ("x", exprNor.getOperand(0).getOperand(0).getIdentifier()); + ASSERT_TRUE(exprNor.getOperand(0).getOperand(1).isVariable()); + ASSERT_EQ("y", exprNor.getOperand(0).getOperand(1).getIdentifier()); +} +#endif \ No newline at end of file diff --git a/test/functional/parser/PrismParserTest.cpp b/test/functional/parser/PrismParserTest.cpp index 3621149b8..dbbf5b4a4 100644 --- a/test/functional/parser/PrismParserTest.cpp +++ b/test/functional/parser/PrismParserTest.cpp @@ -94,7 +94,7 @@ TEST(PrismParser, ComplexTest) { endrewards)"; storm::prism::Program result; - result = storm::parser::PrismParser::parseFromString(testInput, "testfile"); + EXPECT_NO_THROW(result = storm::parser::PrismParser::parseFromString(testInput, "testfile")); EXPECT_EQ(storm::prism::Program::ModelType::MA, result.getModelType()); EXPECT_EQ(3, result.getNumberOfModules()); EXPECT_EQ(2, result.getNumberOfRewardModels()); diff --git a/test/functional/solver/GlpkLpSolverTest.cpp b/test/functional/solver/GlpkLpSolverTest.cpp index a6410611d..32048bab9 100644 --- a/test/functional/solver/GlpkLpSolverTest.cpp +++ b/test/functional/solver/GlpkLpSolverTest.cpp @@ -1,13 +1,13 @@ #include "gtest/gtest.h" #include "storm-config.h" +#ifdef STORM_HAVE_GLPK #include "src/solver/GlpkLpSolver.h" #include "src/exceptions/InvalidStateException.h" #include "src/exceptions/InvalidAccessException.h" #include "src/settings/Settings.h" TEST(GlpkLpSolver, LPOptimizeMax) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -35,13 +35,9 @@ TEST(GlpkLpSolver, LPOptimizeMax) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - 14.75), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, LPOptimizeMin) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Minimize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -69,13 +65,9 @@ TEST(GlpkLpSolver, LPOptimizeMin) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - (-6.7)), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, MILPOptimizeMax) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedIntegerVariable("y", 0, 2)); @@ -103,13 +95,9 @@ TEST(GlpkLpSolver, MILPOptimizeMax) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - 14), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, MILPOptimizeMin) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Minimize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedIntegerVariable("y", 0, 2)); @@ -137,13 +125,9 @@ TEST(GlpkLpSolver, MILPOptimizeMin) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - (-6)), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, LPInfeasible) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -168,13 +152,9 @@ TEST(GlpkLpSolver, LPInfeasible) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, MILPInfeasible) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -199,13 +179,9 @@ TEST(GlpkLpSolver, MILPInfeasible) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, LPUnbounded) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -228,13 +204,9 @@ TEST(GlpkLpSolver, LPUnbounded) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } TEST(GlpkLpSolver, MILPUnbounded) { -#ifdef STORM_HAVE_GLPK storm::solver::GlpkLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -257,7 +229,5 @@ TEST(GlpkLpSolver, MILPUnbounded) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without glpk support."; -#endif } +#endif \ No newline at end of file diff --git a/test/functional/solver/GurobiLpSolverTest.cpp b/test/functional/solver/GurobiLpSolverTest.cpp index ee0ebfa03..a862fac42 100644 --- a/test/functional/solver/GurobiLpSolverTest.cpp +++ b/test/functional/solver/GurobiLpSolverTest.cpp @@ -1,13 +1,13 @@ #include "gtest/gtest.h" #include "storm-config.h" +#ifdef STORM_HAVE_GUROBI #include "src/solver/GurobiLpSolver.h" #include "src/exceptions/InvalidStateException.h" #include "src/exceptions/InvalidAccessException.h" #include "src/settings/Settings.h" TEST(GurobiLpSolver, LPOptimizeMax) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -35,13 +35,9 @@ TEST(GurobiLpSolver, LPOptimizeMax) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - 14.75), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, LPOptimizeMin) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Minimize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedIntegerVariable("y", 0, 2)); @@ -69,13 +65,9 @@ TEST(GurobiLpSolver, LPOptimizeMin) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - (-6.7)), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, MILPOptimizeMax) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedIntegerVariable("y", 0, 2)); @@ -103,13 +95,9 @@ TEST(GurobiLpSolver, MILPOptimizeMax) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - 14), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, MILPOptimizeMin) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Minimize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedIntegerVariable("y", 0, 2)); @@ -137,13 +125,9 @@ TEST(GurobiLpSolver, MILPOptimizeMin) { double objectiveValue = 0; ASSERT_NO_THROW(objectiveValue = solver.getObjectiveValue()); ASSERT_LT(std::abs(objectiveValue - (-6)), storm::settings::Settings::getInstance()->getOptionByLongName("precision").getArgument(0).getValueAsDouble()); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, LPInfeasible) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -168,13 +152,9 @@ TEST(GurobiLpSolver, LPInfeasible) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, MILPInfeasible) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -199,13 +179,9 @@ TEST(GurobiLpSolver, MILPInfeasible) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, LPUnbounded) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBoundedContinuousVariable("x", 0, 1, -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -228,13 +204,9 @@ TEST(GurobiLpSolver, LPUnbounded) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } TEST(GurobiLpSolver, MILPUnbounded) { -#ifdef STORM_HAVE_GUROBI storm::solver::GurobiLpSolver solver(storm::solver::LpSolver::ModelSense::Maximize); ASSERT_NO_THROW(solver.addBinaryVariable("x", -1)); ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("y", 0, 2)); @@ -256,7 +228,5 @@ TEST(GurobiLpSolver, MILPUnbounded) { ASSERT_THROW(zValue = solver.getContinuousValue("z"), storm::exceptions::InvalidAccessException); double objectiveValue = 0; ASSERT_THROW(objectiveValue = solver.getObjectiveValue(), storm::exceptions::InvalidAccessException); -#else - ASSERT_TRUE(false) << "StoRM built without Gurobi support."; -#endif } +#endif \ No newline at end of file diff --git a/test/functional/solver/Z3SmtSolverTest.cpp b/test/functional/solver/Z3SmtSolverTest.cpp new file mode 100644 index 000000000..c3222404c --- /dev/null +++ b/test/functional/solver/Z3SmtSolverTest.cpp @@ -0,0 +1,226 @@ +#include "gtest/gtest.h" +#include "storm-config.h" + +#ifdef STORM_HAVE_Z3 +#include "src/solver/Z3SmtSolver.h" +#include "src/settings/Settings.h" + +TEST(Z3SmtSolver, CheckSat) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression exprDeMorgan = !(storm::expressions::Expression::createBooleanVariable("x") && storm::expressions::Expression::createBooleanVariable("y")).iff((!storm::expressions::Expression::createBooleanVariable("x") || !storm::expressions::Expression::createBooleanVariable("y"))); + + ASSERT_NO_THROW(s.assertExpression(exprDeMorgan)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.reset()); + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression c = storm::expressions::Expression::createIntegerVariable("c"); + storm::expressions::Expression exprFormula = a >= storm::expressions::Expression::createIntegerLiteral(0) + && a < storm::expressions::Expression::createIntegerLiteral(5) + && b > storm::expressions::Expression::createIntegerLiteral(7) + && c == (a * b) + && b + a > c; + + ASSERT_NO_THROW(s.assertExpression(exprFormula)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.reset()); +} + +TEST(Z3SmtSolver, CheckUnsat) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression exprDeMorgan = !(storm::expressions::Expression::createBooleanVariable("x") && storm::expressions::Expression::createBooleanVariable("y")).iff( (!storm::expressions::Expression::createBooleanVariable("x") || !storm::expressions::Expression::createBooleanVariable("y"))); + + ASSERT_NO_THROW(s.assertExpression(!exprDeMorgan)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + ASSERT_NO_THROW(s.reset()); + + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression c = storm::expressions::Expression::createIntegerVariable("c"); + storm::expressions::Expression exprFormula = a >= storm::expressions::Expression::createIntegerLiteral(2) + && a < storm::expressions::Expression::createIntegerLiteral(5) + && b > storm::expressions::Expression::createIntegerLiteral(7) + && c == (a * b) + && b + a > c; + + ASSERT_NO_THROW(s.assertExpression(exprFormula)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); +} + + +TEST(Z3SmtSolver, Backtracking) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression expr1 = storm::expressions::Expression::createTrue(); + storm::expressions::Expression expr2 = storm::expressions::Expression::createFalse(); + storm::expressions::Expression expr3 = storm::expressions::Expression::createFalse(); + + ASSERT_NO_THROW(s.assertExpression(expr1)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.push()); + ASSERT_NO_THROW(s.assertExpression(expr2)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + ASSERT_NO_THROW(s.pop()); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.push()); + ASSERT_NO_THROW(s.assertExpression(expr2)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + ASSERT_NO_THROW(s.push()); + ASSERT_NO_THROW(s.assertExpression(expr3)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + ASSERT_NO_THROW(s.pop(2)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.reset()); + + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression c = storm::expressions::Expression::createIntegerVariable("c"); + storm::expressions::Expression exprFormula = a >= storm::expressions::Expression::createIntegerLiteral(0) + && a < storm::expressions::Expression::createIntegerLiteral(5) + && b > storm::expressions::Expression::createIntegerLiteral(7) + && c == (a * b) + && b + a > c; + storm::expressions::Expression exprFormula2 = a >= storm::expressions::Expression::createIntegerLiteral(2); + + ASSERT_NO_THROW(s.assertExpression(exprFormula)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.push()); + ASSERT_NO_THROW(s.assertExpression(exprFormula2)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + ASSERT_NO_THROW(s.pop()); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); +} + +TEST(Z3SmtSolver, Assumptions) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression c = storm::expressions::Expression::createIntegerVariable("c"); + storm::expressions::Expression exprFormula = a >= storm::expressions::Expression::createIntegerLiteral(0) + && a < storm::expressions::Expression::createIntegerLiteral(5) + && b > storm::expressions::Expression::createIntegerLiteral(7) + && c == (a * b) + && b + a > c; + storm::expressions::Expression f2 = storm::expressions::Expression::createBooleanVariable("f2"); + storm::expressions::Expression exprFormula2 = f2.implies(a >= storm::expressions::Expression::createIntegerLiteral(2)); + + ASSERT_NO_THROW(s.assertExpression(exprFormula)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(s.assertExpression(exprFormula2)); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(result = s.checkWithAssumptions({f2})); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + ASSERT_NO_THROW(result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + ASSERT_NO_THROW(result = s.checkWithAssumptions({ !f2 })); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); +} + +TEST(Z3SmtSolver, GenerateModel) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression c = storm::expressions::Expression::createIntegerVariable("c"); + storm::expressions::Expression exprFormula = a > storm::expressions::Expression::createIntegerLiteral(0) + && a < storm::expressions::Expression::createIntegerLiteral(5) + && b > storm::expressions::Expression::createIntegerLiteral(7) + && c == (a * b) + && b + a > c; + + (s.assertExpression(exprFormula)); + (result = s.check()); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::SAT); + storm::expressions::SimpleValuation model; + (model = s.getModel()); + int_fast64_t a_eval; + (a_eval = model.getIntegerValue("a")); + ASSERT_EQ(1, a_eval); +} + + +TEST(Z3SmtSolver, AllSat) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression x = storm::expressions::Expression::createBooleanVariable("x"); + storm::expressions::Expression y = storm::expressions::Expression::createBooleanVariable("y"); + storm::expressions::Expression z = storm::expressions::Expression::createBooleanVariable("z"); + storm::expressions::Expression exprFormula1 = x.implies(a > storm::expressions::Expression::createIntegerLiteral(5)); + storm::expressions::Expression exprFormula2 = y.implies(a < storm::expressions::Expression::createIntegerLiteral(5)); + storm::expressions::Expression exprFormula3 = z.implies(b < storm::expressions::Expression::createIntegerLiteral(5)); + + (s.assertExpression(exprFormula1)); + (s.assertExpression(exprFormula2)); + (s.assertExpression(exprFormula3)); + + std::vector valuations = s.allSat({x,y}); + + ASSERT_TRUE(valuations.size() == 3); + for (int i = 0; i < valuations.size(); ++i) { + ASSERT_EQ(valuations[i].getNumberOfIdentifiers(), 2); + ASSERT_TRUE(valuations[i].containsBooleanIdentifier("x")); + ASSERT_TRUE(valuations[i].containsBooleanIdentifier("y")); + } + for (int i = 0; i < valuations.size(); ++i) { + ASSERT_FALSE(valuations[i].getBooleanValue("x") && valuations[i].getBooleanValue("y")); + + for (int j = i+1; j < valuations.size(); ++j) { + ASSERT_TRUE((valuations[i].getBooleanValue("x") != valuations[j].getBooleanValue("x")) || (valuations[i].getBooleanValue("y") != valuations[j].getBooleanValue("y"))); + } + } +} + +TEST(Z3SmtSolver, UnsatAssumptions) { + storm::solver::Z3SmtSolver s; + storm::solver::Z3SmtSolver::CheckResult result; + + storm::expressions::Expression a = storm::expressions::Expression::createIntegerVariable("a"); + storm::expressions::Expression b = storm::expressions::Expression::createIntegerVariable("b"); + storm::expressions::Expression c = storm::expressions::Expression::createIntegerVariable("c"); + storm::expressions::Expression exprFormula = a >= storm::expressions::Expression::createIntegerLiteral(0) + && a < storm::expressions::Expression::createIntegerLiteral(5) + && b > storm::expressions::Expression::createIntegerLiteral(7) + && c == (a * b) + && b + a > c; + storm::expressions::Expression f2 = storm::expressions::Expression::createBooleanVariable("f2"); + storm::expressions::Expression exprFormula2 = f2.implies(a >= storm::expressions::Expression::createIntegerLiteral(2)); + + (s.assertExpression(exprFormula)); + (s.assertExpression(exprFormula2)); + (result = s.checkWithAssumptions({ f2 })); + ASSERT_TRUE(result == storm::solver::SmtSolver::CheckResult::UNSAT); + std::vector unsatCore = s.getUnsatAssumptions(); + ASSERT_EQ(unsatCore.size(), 1); + ASSERT_TRUE(unsatCore[0].isVariable()); + ASSERT_STREQ("f2", unsatCore[0].getIdentifier().c_str()); +} +#endif \ No newline at end of file