Browse Source

Started refactoring MathSAT adapter.

Former-commit-id: 93b1fdedb3
tempestpy_adaptions
dehnert 10 years ago
parent
commit
a061cdbed8
  1. 236
      src/adapters/MathsatExpressionAdapter.h
  2. 11
      src/adapters/Z3ExpressionAdapter.h
  3. 12
      test/functional/adapter/Z3ExpressionAdapterTest.cpp

236
src/adapters/MathsatExpressionAdapter.h

@ -1,9 +1,3 @@
/*
* MathSatExpressionAdapter.h
*
* Author: David Korzeniewski
*/
#ifndef STORM_ADAPTERS_MATHSATEXPRESSIONADAPTER_H_
#define STORM_ADAPTERS_MATHSATEXPRESSIONADAPTER_H_
@ -24,71 +18,37 @@
namespace storm {
namespace adapters {
#ifdef STORM_HAVE_MSAT
class MathsatExpressionAdapter : public storm::expressions::ExpressionVisitor {
public:
/*!
* Creates a MathsatExpressionAdapter over the given MathSAT enviroment.
* Creates an expression adapter that can translate expressions to the format of Z3.
*
* @warning The adapter internally creates helper variables prefixed with `__z3adapter_`. As a consequence,
* having variables with this prefix in the variableToExpressionMap might lead to unexpected results and is
* strictly to be avoided.
*
* @param context A reference to the MathSAT enviroment over which to build the expressions. Be careful to guarantee
* the lifetime of the context as long as the instance of this adapter is used.
* @param variableToDeclMap A mapping from variable names to their corresponding MathSAT Declarations.
* @param context A reference to the Z3 context over which to build the expressions. The lifetime of the
* context needs to be guaranteed as long as the instance of this adapter is used.
* @param createVariables If set to true, additional variables will be created for variables that appear in
* expressions and are not yet known to the adapter.
* @param variableToDeclarationMap A mapping from variable names to their corresponding MathSAT declarations (if already existing).
*/
MathsatExpressionAdapter(msat_env& env, std::map<std::string, msat_decl> const& variableToDeclarationMap = std::map<std::string, msat_decl>()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap) {
MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map<std::string, msat_decl> const& variableToDeclarationMap = std::map<std::string, msat_decl>()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap) {
// Intentionally left empty.
}
/*!
* Translates the given expression to an equivalent term for MathSAT.
*
* @param expression The expression to translate.
* @param createMathSatVariables If set to true a solver variable is created for each variable in expression that is not
* yet known to the adapter. (i.e. values from the variableToExpressionMap passed to the constructor
* are not overwritten)
* @param expression The expression to be translated.
* @return An equivalent term for MathSAT.
*/
msat_term translateExpression(storm::expressions::Expression const& expression, bool createMathSatVariables = false) {
if (createMathSatVariables) {
std::map<std::string, storm::expressions::ExpressionReturnType> variables;
try {
variables = expression.getVariablesAndTypes();
}
catch (storm::exceptions::InvalidTypeException* e) {
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with ambigious type while trying to autocreate solver variables: " << e);
}
for (auto variableAndType : variables) {
if (this->variableToDeclarationMap.find(variableAndType.first) == this->variableToDeclarationMap.end()) {
switch (variableAndType.second)
{
case storm::expressions::ExpressionReturnType::Bool:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_bool_type(env))));
break;
case storm::expressions::ExpressionReturnType::Int:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_integer_type(env))));
break;
case storm::expressions::ExpressionReturnType::Double:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_rational_type(env))));
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with unknown type while trying to autocreate solver variables: " << variableAndType.first);
break;
}
}
}
}
//LOG4CPLUS_TRACE(logger, "Translating expression:\n" << expression->toString());
msat_term translateExpression(storm::expressions::Expression const& expression) {
expression.getBaseExpression().accept(this);
msat_term result = stack.top();
stack.pop();
if (MSAT_ERROR_TERM(result)) {
//LOG4CPLUS_WARN(logger, "Translating term to MathSAT returned an error!");
}
char* repr = msat_term_repr(result);
//LOG4CPLUS_TRACE(logger, "Result is:\n" << repr);
msat_free(repr);
STORM_LOG_THROW(!MSAT_ERROR_TERM(result), storm::exceptions::ExpressionEvaluationException, "Could not translate expression to MathSAT's format.");
return result;
}
@ -101,13 +61,6 @@ namespace storm {
msat_term leftResult = stack.top();
stack.pop();
//char* repr = msat_term_repr(leftResult);
//LOG4CPLUS_TRACE(logger, "LHS: "<<repr);
//msat_free(repr);
//repr = msat_term_repr(rightResult);
//LOG4CPLUS_TRACE(logger, "RHS: "<<repr);
//msat_free(repr);
switch (expression->getOperatorType()) {
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And:
stack.push(msat_make_and(env, leftResult, rightResult));
@ -121,7 +74,8 @@ namespace storm {
case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies:
stack.push(msat_make_or(env, msat_make_not(env, leftResult), rightResult));
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
default:
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".";
}
@ -260,6 +214,37 @@ namespace storm {
}
virtual void visit(expressions::VariableExpression const* expression) override {
if (createMathSatVariables) {
std::map<std::string, storm::expressions::ExpressionReturnType> variables;
try {
variables = expression.getVariablesAndTypes();
}
catch (storm::exceptions::InvalidTypeException* e) {
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with ambigious type while trying to autocreate solver variables: " << e);
}
for (auto variableAndType : variables) {
if (this->variableToDeclarationMap.find(variableAndType.first) == this->variableToDeclarationMap.end()) {
switch (variableAndType.second)
{
case storm::expressions::ExpressionReturnType::Bool:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_bool_type(env))));
break;
case storm::expressions::ExpressionReturnType::Int:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_integer_type(env))));
break;
case storm::expressions::ExpressionReturnType::Double:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_rational_type(env))));
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with unknown type while trying to autocreate solver variables: " << variableAndType.first);
break;
}
}
}
}
STORM_LOG_THROW(variableToDeclarationMap.count(expression->getVariableName()) != 0, storm::exceptions::InvalidArgumentException, "Variable '" << expression->getVariableName() << "' is unknown.");
//LOG4CPLUS_TRACE(logger, "Variable "<<expression->getVariableName());
//char* repr = msat_decl_repr(variableToDeclMap.at(expression->getVariableName()));
@ -271,132 +256,71 @@ namespace storm {
stack.push(msat_make_constant(env, variableToDeclarationMap.at(expression->getVariableName())));
}
storm::expressions::Expression translateTerm(msat_term term) {
this->processTerm(term);
storm::expressions::Expression result = std::move(expression_stack.top());
expression_stack.pop();
return result;
}
void processTerm(msat_term term) {
storm::expressions::Expression translateExpression(msat_term const& term) {
if (msat_term_is_and(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult &&rightResult);
return translateExpression(msat_term_get_arg(term, 0)) && translateExpression(msat_term_get_arg(term, 1));
} else if (msat_term_is_or(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult && rightResult);
return translateExpression(msat_term_get_arg(term, 0)) || translateExpression(msat_term_get_arg(term, 1));
} else if (msat_term_is_iff(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult.iff(rightResult));
return translateExpression(msat_term_get_arg(term, 0)).iff(translateExpression(msat_term_get_arg(term, 1)));
} else if (msat_term_is_not(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
storm::expressions::Expression childResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(!childResult);
return !translateExpression(msat_term_get_arg(term, 0));
} else if (msat_term_is_plus(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult+rightResult);
return translateExpression(msat_term_get_arg(term, 0)) + translateExpression(msat_term_get_arg(term, 1));
} else if (msat_term_is_times(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult * rightResult);
return translateExpression(msat_term_get_arg(term, 0)) * translateExpression(msat_term_get_arg(term, 1));
} else if (msat_term_is_equal(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult == rightResult);
return translateExpression(msat_term_get_arg(term, 0)) == translateExpression(msat_term_get_arg(term, 1));
} else if (msat_term_is_leq(env, term)) {
this->processTerm(msat_term_get_arg(term, 0));
this->processTerm(msat_term_get_arg(term, 1));
storm::expressions::Expression rightResult = std::move(expression_stack.top());
expression_stack.pop();
storm::expressions::Expression leftResult = std::move(expression_stack.top());
expression_stack.pop();
expression_stack.push(leftResult <= rightResult);
return translateExpression(msat_term_get_arg(term, 0)) <= translateExpression(msat_term_get_arg(term, 1));
} else if (msat_term_is_true(env, term)) {
expression_stack.push(expressions::Expression::createTrue());
return storm::expressions::Expression::createTrue();
} else if (msat_term_is_false(env, term)) {
expression_stack.push(expressions::Expression::createFalse());
return storm::expressions::Expression::createFalse();
} else if (msat_term_is_boolean_constant(env, term)) {
char* name = msat_decl_get_name(msat_term_get_decl(term));
std::string name_str(name);
expression_stack.push(expressions::Expression::createBooleanVariable(name_str.substr(0, name_str.find('/'))));
storm::expressions::Expression result = expressions::Expression::createBooleanVariable(name_str.substr(0, name_str.find('/')));
msat_free(name);
return result;
} else if (msat_term_is_constant(env, term)) {
char* name = msat_decl_get_name(msat_term_get_decl(term));
std::string name_str(name);
storm::expressions::Expression result;
if (msat_is_integer_type(env, msat_term_get_type(term))) {
expression_stack.push(expressions::Expression::createIntegerVariable(name_str.substr(0, name_str.find('/'))));
result = expressions::Expression::createIntegerVariable(name_str.substr(0, name_str.find('/')));
} else if (msat_is_rational_type(env, msat_term_get_type(term))) {
expression_stack.push(expressions::Expression::createDoubleVariable(name_str.substr(0, name_str.find('/'))));
result = expressions::Expression::createDoubleVariable(name_str.substr(0, name_str.find('/')));
}
msat_free(name);
return result;
} else if (msat_term_is_number(env, term)) {
if (msat_is_integer_type(env, msat_term_get_type(term))) {
expression_stack.push(expressions::Expression::createIntegerLiteral(std::stoll(msat_term_repr(term))));
return expressions::Expression::createIntegerLiteral(std::stoll(msat_term_repr(term)));
} else if (msat_is_rational_type(env, msat_term_get_type(term))) {
expression_stack.push(expressions::Expression::createDoubleLiteral(std::stod(msat_term_repr(term))));
return expressions::Expression::createDoubleLiteral(std::stod(msat_term_repr(term)));
}
} else {
char* term_cstr = msat_term_repr(term);
std::string term_str(term_cstr);
msat_free(term_cstr);
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown term: '" << term_str << "'.";
}
// If all other cases did not apply, we cannot represent the term in our expression framework.
char* termAsCString = msat_term_repr(term);
std::string termString(termAsCString);
msat_free(termAsCString);
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot translate expression: unknown term: '" << termString << "'.");
}
private:
// The MathSAT environment used.
msat_env& env;
// A stack used for communicating results between different functions.
std::stack<msat_term> stack;
std::stack<expressions::Expression> expression_stack;
// A mapping of variable names to their declaration in the MathSAT environment.
std::map<std::string, msat_decl> variableToDeclarationMap;
};
#endif
} // namespace adapters
} // namespace storm

11
src/adapters/Z3ExpressionAdapter.h

@ -1,10 +1,3 @@
/*
* Z3ExpressionAdapter.h
*
* Created on: 04.10.2013
* Author: Christian Dehnert
*/
#ifndef STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_
#define STORM_ADAPTERS_Z3EXPRESSIONADAPTER_H_
@ -38,11 +31,11 @@ namespace storm {
*
* @param context A reference to the Z3 context over which to build the expressions. The lifetime of the
* context needs to be guaranteed as long as the instance of this adapter is used.
* @param variableToExpressionMap A mapping from variable names to their corresponding Z3 expressions (if already existing).
* @param createVariables If set to true, additional variables will be created for variables that appear in
* expressions and are not yet known to the adapter.
* @param variableToExpressionMap A mapping from variable names to their corresponding Z3 expressions (if already existing).
*/
Z3ExpressionAdapter(z3::context& context, std::map<std::string, z3::expr> const& variableToExpressionMap = std::map<std::string, z3::expr>(), bool createVariables = false) : context(context) , stack() , additionalAssertions() , additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) {
Z3ExpressionAdapter(z3::context& context, bool createVariables = true, std::map<std::string, z3::expr> const& variableToExpressionMap = std::map<std::string, z3::expr>()) : context(context) , stack() , additionalAssertions() , additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) {
// Intentionally left empty.
}

12
test/functional/adapter/Z3ExpressionAdapterTest.cpp

@ -11,8 +11,8 @@ TEST(Z3ExpressionAdapter, StormToZ3Basic) {
z3::solver s(ctx);
z3::expr conjecture = ctx.bool_val(false);
storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map<std::string, z3::expr>(), true);
storm::adapters::Z3ExpressionAdapter adapter2(ctx, std::map<std::string, z3::expr>(), false);
storm::adapters::Z3ExpressionAdapter adapter(ctx);
storm::adapters::Z3ExpressionAdapter adapter2(ctx, false);
storm::expressions::Expression exprTrue = storm::expressions::Expression::createTrue();
z3::expr z3True = ctx.bool_val(true);
@ -51,7 +51,7 @@ TEST(Z3ExpressionAdapter, StormToZ3Integer) {
z3::solver s(ctx);
z3::expr conjecture = ctx.bool_val(false);
storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map<std::string, z3::expr>(), true);
storm::adapters::Z3ExpressionAdapter adapter(ctx, true);
storm::expressions::Expression exprAdd = (storm::expressions::Expression::createIntegerVariable("x") + storm::expressions::Expression::createIntegerVariable("y") < -storm::expressions::Expression::createIntegerVariable("y"));
z3::expr z3Add = (ctx.int_const("x") + ctx.int_const("y") < -ctx.int_const("y"));
@ -73,7 +73,7 @@ TEST(Z3ExpressionAdapter, StormToZ3Real) {
z3::solver s(ctx);
z3::expr conjecture = ctx.bool_val(false);
storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map<std::string, z3::expr>(), true);
storm::adapters::Z3ExpressionAdapter adapter(ctx);
storm::expressions::Expression exprAdd = (storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createDoubleVariable("y") < -storm::expressions::Expression::createDoubleVariable("y"));
z3::expr z3Add = (ctx.real_const("x") + ctx.real_const("y") < -ctx.real_const("y"));
@ -95,7 +95,7 @@ TEST(Z3ExpressionAdapter, StormToZ3FloorCeil) {
z3::solver s(ctx);
z3::expr conjecture = ctx.bool_val(false);
storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map<std::string, z3::expr>(), true);
storm::adapters::Z3ExpressionAdapter adapter(ctx);
storm::expressions::Expression exprFloor = ((storm::expressions::Expression::createDoubleVariable("d").floor()) == storm::expressions::Expression::createIntegerVariable("i") && storm::expressions::Expression::createDoubleVariable("d") > storm::expressions::Expression::createDoubleLiteral(4.1) && storm::expressions::Expression::createDoubleVariable("d") < storm::expressions::Expression::createDoubleLiteral(4.991));
z3::expr z3Floor = ctx.int_val(4) == ctx.int_const("i");
@ -133,7 +133,7 @@ TEST(Z3ExpressionAdapter, Z3ToStormBasic) {
z3::context ctx;
unsigned args = 2;
storm::adapters::Z3ExpressionAdapter adapter(ctx, std::map<std::string, z3::expr>());
storm::adapters::Z3ExpressionAdapter adapter(ctx);
z3::expr z3True = ctx.bool_val(true);
storm::expressions::Expression exprTrue;

Loading…
Cancel
Save