#ifndef STORM_ADAPTERS_MATHSATEXPRESSIONADAPTER_H_ #define STORM_ADAPTERS_MATHSATEXPRESSIONADAPTER_H_ #include "storm-config.h" #include #ifdef STORM_HAVE_MSAT #include "mathsat.h" #endif #include "storage/expressions/ExpressionManager.h" #include "storage/expressions/Expressions.h" #include "storage/expressions/ExpressionVisitor.h" #include "src/utility/macros.h" #include "src/exceptions/ExpressionEvaluationException.h" #include "src/exceptions/InvalidTypeException.h" #include "src/exceptions/InvalidArgumentException.h" #include "src/exceptions/NotImplementedException.h" #ifdef STORM_HAVE_MSAT namespace std { // Define hashing operator for MathSAT's declarations. template <> struct hash { size_t operator()(msat_decl const& declaration) const { return hash()(declaration.repr); } }; } // Define equality operator to make hashing work. bool operator==(msat_decl decl1, msat_decl decl2); #endif namespace storm { namespace adapters { #ifdef STORM_HAVE_MSAT class MathsatExpressionAdapter : public storm::expressions::ExpressionVisitor { public: /*! * Creates an expression adapter that can translate expressions to the format of MathSAT. * * @param manager The manager that can be used to build expressions. * @param env The MathSAT environment in which to build the expressions. */ MathsatExpressionAdapter(storm::expressions::ExpressionManager& manager, msat_env& env) : manager(manager), env(env), variableToDeclarationMapping() { // Intentionally left empty. } /*! * Translates the given expression to an equivalent term for MathSAT. * * @param expression The expression to be translated. * @return An equivalent term for MathSAT. */ msat_term translateExpression(storm::expressions::Expression const& expression) { msat_term result = boost::any_cast(expression.getBaseExpression().accept(*this)); STORM_LOG_THROW(!MSAT_ERROR_TERM(result), storm::exceptions::ExpressionEvaluationException, "Could not translate expression to MathSAT's format."); return result; } /*! * Translates the given variable to an equivalent expression for Z3. * * @param variable The variable to translate. * @return An equivalent term for MathSAT. */ msat_term translateExpression(storm::expressions::Variable const& variable) { STORM_LOG_ASSERT(variable.getManager() == this->manager, "Invalid expression for solver."); auto const& variableExpressionPair = variableToDeclarationMapping.find(variable); if (variableExpressionPair == variableToDeclarationMapping.end()) { return msat_make_constant(env, createVariable(variable)); } return msat_make_constant(env, variableExpressionPair->second); } /*! * Retrieves the variable that is associated with the given MathSAT variable declaration. * * @param msatVariableDeclaration The MathSAT variable declaration. * @return The variable associated with the given declaration. */ storm::expressions::Variable const& getVariable(msat_decl msatVariableDeclaration) const { auto const& declarationVariablePair = declarationToVariableMapping.find(msatVariableDeclaration); STORM_LOG_ASSERT(declarationVariablePair != declarationToVariableMapping.end(), "Unknown variable declaration."); return declarationVariablePair->second; } virtual boost::any visit(expressions::BinaryBooleanFunctionExpression const& expression) override { msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); msat_term rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); switch (expression.getOperatorType()) { case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: return msat_make_and(env, leftResult, rightResult); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: return msat_make_or(env, leftResult, rightResult); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: return msat_make_iff(env, leftResult, rightResult); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: return msat_make_or(env, msat_make_not(env, leftResult), rightResult); default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } virtual boost::any visit(expressions::BinaryNumericalFunctionExpression const& expression) override { msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); msat_term rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); switch (expression.getOperatorType()) { case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus: return msat_make_plus(env, leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus: return msat_make_plus(env, leftResult, msat_make_times(env, msat_make_number(env, "-1"), rightResult)); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times: return msat_make_times(env, leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unsupported numerical binary operator: '/' (division) in expression."); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min: return msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: return msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult); default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } virtual boost::any visit(expressions::BinaryRelationExpression const& expression) override { msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); msat_term rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); switch (expression.getRelationType()) { case storm::expressions::BinaryRelationExpression::RelationType::Equal: if (expression.getFirstOperand()->getType().isBooleanType() && expression.getSecondOperand()->getType().isBooleanType()) { return msat_make_iff(env, leftResult, rightResult); } else { return msat_make_equal(env, leftResult, rightResult); } case storm::expressions::BinaryRelationExpression::RelationType::NotEqual: if (expression.getFirstOperand()->getType().isBooleanType() && expression.getSecondOperand()->getType().isBooleanType()) { return msat_make_not(env, msat_make_iff(env, leftResult, rightResult)); } else { return msat_make_not(env, msat_make_equal(env, leftResult, rightResult)); } case storm::expressions::BinaryRelationExpression::RelationType::Less: return msat_make_and(env, msat_make_not(env, msat_make_equal(env, leftResult, rightResult)), msat_make_leq(env, leftResult, rightResult)); case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual: return msat_make_leq(env, leftResult, rightResult); case storm::expressions::BinaryRelationExpression::RelationType::Greater: return msat_make_not(env, msat_make_leq(env, leftResult, rightResult)); case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: return msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult))); default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getRelationType()) << "' in expression " << expression << "."); } } virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression) override { msat_term conditionResult = boost::any_cast(expression.getCondition()->accept(*this)); msat_term thenResult = boost::any_cast(expression.getThenExpression()->accept(*this)); msat_term elseResult = boost::any_cast(expression.getElseExpression()->accept(*this)); return msat_make_term_ite(env, conditionResult, thenResult, elseResult); } virtual boost::any visit(expressions::BooleanLiteralExpression const& expression) override { return expression.getValue() ? msat_make_true(env) : msat_make_false(env); } virtual boost::any visit(expressions::DoubleLiteralExpression const& expression) override { return msat_make_number(env, std::to_string(expression.getValueAsDouble()).c_str()); } virtual boost::any visit(expressions::IntegerLiteralExpression const& expression) override { return msat_make_number(env, std::to_string(static_cast(expression.getValue())).c_str()); } virtual boost::any visit(expressions::UnaryBooleanFunctionExpression const& expression) override { msat_term childResult = boost::any_cast(expression.getOperand()->accept(*this)); switch (expression.getOperatorType()) { case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: return msat_make_not(env, childResult); break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } virtual boost::any visit(expressions::UnaryNumericalFunctionExpression const& expression) override { msat_term childResult = boost::any_cast(expression.getOperand()->accept(*this)); switch (expression.getOperatorType()) { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: return msat_make_times(env, msat_make_number(env, "-1"), childResult); break; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: return msat_make_floor(env, childResult); break; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil: return msat_make_plus(env, msat_make_floor(env, childResult), msat_make_number(env, "1")); break; default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } virtual boost::any visit(expressions::VariableExpression const& expression) override { return translateExpression(expression.getVariable()); } storm::expressions::Expression translateExpression(msat_term const& term) { if (msat_term_is_and(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) && translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_or(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) || translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_iff(env, term)) { return storm::expressions::iff(translateExpression(msat_term_get_arg(term, 0)), translateExpression(msat_term_get_arg(term, 1))); } else if (msat_term_is_not(env, term)) { return !translateExpression(msat_term_get_arg(term, 0)); } else if (msat_term_is_plus(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) + translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_times(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) * translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_equal(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) == translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_leq(env, term)) { return translateExpression(msat_term_get_arg(term, 0)) <= translateExpression(msat_term_get_arg(term, 1)); } else if (msat_term_is_true(env, term)) { return manager.boolean(true); } else if (msat_term_is_false(env, term)) { return manager.boolean(false); } else if (msat_term_is_constant(env, term)) { char* name = msat_decl_get_name(msat_term_get_decl(term)); std::string nameString(name); storm::expressions::Expression result = manager.getVariableExpression(nameString.substr(0, nameString.find('/'))); msat_free(name); return result; } else if (msat_term_is_number(env, term)) { if (msat_is_integer_type(env, msat_term_get_type(term))) { return manager.integer(std::stoll(msat_term_repr(term))); } else if (msat_is_rational_type(env, msat_term_get_type(term))) { return manager.rational(std::stod(msat_term_repr(term))); } } // If all other cases did not apply, we cannot represent the term in our expression framework. char* termAsCString = msat_term_repr(term); std::string termString(termAsCString); msat_free(termAsCString); STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot translate expression: unknown term: '" << termString << "'."); } private: /*! * Creates a MathSAT variable for the provided variable. * * @param variable The variable for which to create a MathSAT counterpart. */ msat_decl createVariable(storm::expressions::Variable const& variable) { msat_decl msatDeclaration; if (variable.getType().isBooleanType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_bool_type(env)); } else if (variable.getType().isIntegerType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_integer_type(env)); } else if (variable.getType().isBitVectorType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_bv_type(env, variable.getType().getWidth())); } else if (variable.getType().isRationalType()) { msatDeclaration = msat_declare_function(env, variable.getName().c_str(), msat_get_rational_type(env)); } else { STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << variable.getName() << "' with unknown type while trying to create solver variables."); } variableToDeclarationMapping.insert(std::make_pair(variable, msatDeclaration)); declarationToVariableMapping.insert(std::make_pair(msatDeclaration, variable)); return msatDeclaration; } // The expression manager to use. storm::expressions::ExpressionManager& manager; // The MathSAT environment used. msat_env& env; // A mapping of variable names to their declaration in the MathSAT environment. std::unordered_map variableToDeclarationMapping; // A mapping from MathSAT variable declaration to our variables. std::unordered_map declarationToVariableMapping; }; #endif } // namespace adapters } // namespace storm #endif /* STORM_ADAPTERS_MATHSATEXPRESSIONADAPTER_H_ */