diff --git a/src/adapters/MathSatExpressionAdapter.h b/src/adapters/MathSatExpressionAdapter.h index 267337075..949ae46c5 100644 --- a/src/adapters/MathSatExpressionAdapter.h +++ b/src/adapters/MathSatExpressionAdapter.h @@ -47,7 +47,7 @@ namespace storm { */ msat_term translateExpression(storm::expressions::Expression const& expression, bool createMathSatVariables = false) { //LOG4CPLUS_TRACE(logger, "Translating expression:\n" << expression->toString()); - expression.accept(this); + expression.getBaseExpression().accept(this); msat_term result = stack.top(); stack.pop(); if (MSAT_ERROR_TERM(result)) { @@ -87,7 +87,7 @@ namespace storm { stack.push(msat_make_iff(env, leftResult, rightResult)); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << expression->getOperatorType() << "' in expression " << expression << "."; + << "Unknown boolean binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; } } @@ -122,7 +122,7 @@ namespace storm { stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult)); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown numerical binary operator: '" << expression->getOperatorType() << "' in expression " << expression << "."; + << "Unknown numerical binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; } } @@ -163,7 +163,7 @@ namespace storm { stack.push(msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult)))); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << expression->getRelationType() << "' in expression " << expression << "."; + << "Unknown boolean binary operator: '" << static_cast(expression->getRelationType()) << "' in expression " << expression << "."; } } @@ -205,7 +205,7 @@ namespace storm { stack.push(msat_make_not(env, childResult)); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << expression->getOperatorType() << "' in expression " << expression << "."; + << "Unknown boolean binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; } } @@ -220,7 +220,7 @@ namespace storm { stack.push(msat_make_times(env, msat_make_number(env, "-1"), childResult)); break; default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown numerical unary operator: '" << expression->getOperatorType() << "'."; + << "Unknown numerical unary operator: '" << static_cast(expression->getOperatorType()) << "'."; } } diff --git a/src/solver/MathSatSmtSolver.cpp b/src/solver/MathSatSmtSolver.cpp index 7d8ee19d1..f03ee3d6e 100644 --- a/src/solver/MathSatSmtSolver.cpp +++ b/src/solver/MathSatSmtSolver.cpp @@ -217,7 +217,13 @@ namespace storm { #ifdef STORM_HAVE_MSAT - struct AllsatValuationsCallbackUserData { + class AllsatValuationsCallbackUserData { + public: + AllsatValuationsCallbackUserData(msat_env &env, + storm::adapters::MathSatExpressionAdapter &adapter, + std::function &callback) + : env(env), adapter(adapter), callback(callback) { + } msat_env &env; storm::adapters::MathSatExpressionAdapter &adapter; std::function &callback; @@ -240,6 +246,12 @@ namespace storm { valuation.addBooleanIdentifier(name_str, currentTermValue); msat_free(name); } + + if (user->callback(valuation)) { + return 1; + } else { + return 0; + } } #endif @@ -257,10 +269,7 @@ namespace storm { msatImportant.push_back(m_adapter->translateExpression(e)); } - AllsatValuationsCallbackUserData allSatUserData; - allSatUserData.adapter = m_adapter; - allSatUserData.env = m_env; - allSatUserData.callback = callback; + AllsatValuationsCallbackUserData allSatUserData(m_env, *m_adapter, callback); int numModels = msat_all_sat(m_env, msatImportant.data(), msatImportant.size(), &allsatValuationsCallback, &allSatUserData); return numModels;