Browse Source

Done refactoring MathSAT expression adapter.

Former-commit-id: 6edb98b86c
tempestpy_adaptions
dehnert 10 years ago
parent
commit
f54b5671ea
  1. 89
      src/adapters/MathsatExpressionAdapter.h
  2. 20
      src/solver/MathsatSmtSolver.cpp
  3. 8
      src/solver/Z3SmtSolver.cpp

89
src/adapters/MathsatExpressionAdapter.h

@ -34,7 +34,7 @@ namespace storm {
* expressions and are not yet known to the adapter.
* @param variableToDeclarationMap A mapping from variable names to their corresponding MathSAT declarations (if already existing).
*/
MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map<std::string, msat_decl> const& variableToDeclarationMap = std::map<std::string, msat_decl>()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap) {
MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map<std::string, msat_decl> const& variableToDeclarationMap = std::map<std::string, msat_decl>()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) {
// Intentionally left empty.
}
@ -75,8 +75,7 @@ namespace storm {
stack.push(msat_make_or(env, msat_make_not(env, leftResult), rightResult));
break;
default:
throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".";
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
}
}
@ -110,8 +109,8 @@ namespace storm {
case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max:
stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult));
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown numerical binary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".";
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
}
}
@ -151,8 +150,8 @@ namespace storm {
case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual:
stack.push(msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult))));
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << static_cast<uint_fast64_t>(expression->getRelationType()) << "' in expression " << expression << ".";
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast<uint_fast64_t>(expression->getRelationType()) << "' in expression " << expression << ".");
}
}
@ -193,8 +192,8 @@ namespace storm {
case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not:
stack.push(msat_make_not(env, childResult));
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".";
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
}
}
@ -208,52 +207,39 @@ namespace storm {
case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus:
stack.push(msat_make_times(env, msat_make_number(env, "-1"), childResult));
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown numerical unary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "'.";
default:
STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast<uint_fast64_t>(expression->getOperatorType()) << "' in expression " << expression << ".");
}
}
virtual void visit(expressions::VariableExpression const* expression) override {
if (createMathSatVariables) {
std::map<std::string, storm::expressions::ExpressionReturnType> variables;
try {
variables = expression.getVariablesAndTypes();
}
catch (storm::exceptions::InvalidTypeException* e) {
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with ambigious type while trying to autocreate solver variables: " << e);
}
for (auto variableAndType : variables) {
if (this->variableToDeclarationMap.find(variableAndType.first) == this->variableToDeclarationMap.end()) {
switch (variableAndType.second)
{
case storm::expressions::ExpressionReturnType::Bool:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_bool_type(env))));
break;
case storm::expressions::ExpressionReturnType::Int:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_integer_type(env))));
break;
case storm::expressions::ExpressionReturnType::Double:
this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_rational_type(env))));
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with unknown type while trying to autocreate solver variables: " << variableAndType.first);
break;
}
}
}
}
std::map<std::string, msat_decl>::iterator stringVariablePair = variableToDeclarationMap.find(expression->getVariableName());
msat_decl result;
STORM_LOG_THROW(variableToDeclarationMap.count(expression->getVariableName()) != 0, storm::exceptions::InvalidArgumentException, "Variable '" << expression->getVariableName() << "' is unknown.");
//LOG4CPLUS_TRACE(logger, "Variable "<<expression->getVariableName());
//char* repr = msat_decl_repr(variableToDeclMap.at(expression->getVariableName()));
//LOG4CPLUS_TRACE(logger, "Decl: "<<repr);
//msat_free(repr);
if (MSAT_ERROR_DECL(variableToDeclarationMap.at(expression->getVariableName()))) {
STORM_LOG_WARN("Encountered an invalid MathSAT declaration.");
}
stack.push(msat_make_constant(env, variableToDeclarationMap.at(expression->getVariableName())));
if (stringVariablePair == variableToDeclarationMap.end() && createVariables) {
std::pair<std::map<std::string, msat_decl>::iterator, bool> iteratorAndFlag;
switch (expression->getReturnType()) {
case storm::expressions::ExpressionReturnType::Bool:
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_bool_type(env))));
result = iteratorAndFlag.first->second;
break;
case storm::expressions::ExpressionReturnType::Int:
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_integer_type(env))));
result = iteratorAndFlag.first->second;
break;
case storm::expressions::ExpressionReturnType::Double:
iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_rational_type(env))));
result = iteratorAndFlag.first->second;
break;
default:
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression->getVariableName() << "' with unknown type while trying to create solver variables.");
}
} else {
STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression->getVariableName() << "'.");
result = stringVariablePair->second;
}
stack.push(msat_make_constant(env, result));
}
storm::expressions::Expression translateExpression(msat_term const& term) {
@ -319,6 +305,9 @@ namespace storm {
// A mapping of variable names to their declaration in the MathSAT environment.
std::map<std::string, msat_decl> variableToDeclarationMap;
// A flag indicating whether variables are supposed to be created if they are not already known to the adapter.
bool createVariables;
};
#endif
} // namespace adapters

20
src/solver/MathsatSmtSolver.cpp

@ -38,28 +38,28 @@ namespace storm {
}
#endif
bool MathsatSmtSolver::MathsatModelReference::getBooleanValue(std::string const& name) const {
msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name), false);
msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name));
msat_term msatValue = msat_get_model_value(env, msatVariable);
STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval.");
storm::expressions::Expression value = expressionAdapter.translateTerm(msatValue);
storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue);
STORM_LOG_THROW(value.hasBooleanReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve boolean value of non-boolean variable '" << name << "'.");
return value.evaluateAsBool();
}
int_fast64_t MathsatSmtSolver::MathsatModelReference::getIntegerValue(std::string const& name) const {
msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name), false);
msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name));
msat_term msatValue = msat_get_model_value(env, msatVariable);
STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval.");
storm::expressions::Expression value = expressionAdapter.translateTerm(msatValue);
storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue);
STORM_LOG_THROW(value.hasIntegralReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve integer value of non-integer variable '" << name << "'.");
return value.evaluateAsInt();
}
double MathsatSmtSolver::MathsatModelReference::getDoubleValue(std::string const& name) const {
msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name), false);
msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name));
msat_term msatValue = msat_get_model_value(env, msatVariable);
STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval.");
storm::expressions::Expression value = expressionAdapter.translateTerm(msatValue);
storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue);
STORM_LOG_THROW(value.hasIntegralReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve double value of non-double variable '" << name << "'.");
return value.evaluateAsDouble();
}
@ -128,7 +128,7 @@ namespace storm {
void MathsatSmtSolver::add(storm::expressions::Expression const& e)
{
#ifdef STORM_HAVE_MSAT
msat_assert_formula(env, expressionAdapter->translateExpression(e, true));
msat_assert_formula(env, expressionAdapter->translateExpression(e));
#else
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without MathSAT support.");
#endif
@ -241,7 +241,7 @@ namespace storm {
msat_term t, v;
msat_model_iterator_next(modelIterator, &t, &v);
storm::expressions::Expression variableInterpretation = this->expressionAdapter->translateTerm(v);
storm::expressions::Expression variableInterpretation = this->expressionAdapter->translateExpression(v);
char* name = msat_decl_get_name(msat_term_get_decl(t));
switch (variableInterpretation.getReturnType()) {
@ -411,7 +411,7 @@ namespace storm {
unsatAssumptions.reserve(numUnsatAssumpations);
for (unsigned int i = 0; i < numUnsatAssumpations; ++i) {
unsatAssumptions.push_back(this->expressionAdapter->translateTerm(msatUnsatAssumptions[i]));
unsatAssumptions.push_back(this->expressionAdapter->translateExpression(msatUnsatAssumptions[i]));
}
return unsatAssumptions;
@ -450,7 +450,7 @@ namespace storm {
STORM_LOG_THROW(!MSAT_ERROR_TERM(interpolant), storm::exceptions::UnexpectedException, "Unable to retrieve an interpolant.");
return this->expressionAdapter->translateTerm(interpolant);
return this->expressionAdapter->translateExpression(interpolant);
#else
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without MathSAT support.");
#endif

8
src/solver/Z3SmtSolver.cpp

@ -14,7 +14,7 @@ namespace storm {
bool Z3SmtSolver::Z3ModelReference::getBooleanValue(std::string const& name) const {
#ifdef STORM_HAVE_Z3
z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name));
z3::expr z3ExprValuation = model.eval(z3Expr, true);
z3::expr z3ExprValuation = model.eval(z3Expr);
return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsBool();
#else
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support.");
@ -24,7 +24,7 @@ namespace storm {
int_fast64_t Z3SmtSolver::Z3ModelReference::getIntegerValue(std::string const& name) const {
#ifdef STORM_HAVE_Z3
z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createIntegerVariable(name));
z3::expr z3ExprValuation = model.eval(z3Expr, true);
z3::expr z3ExprValuation = model.eval(z3Expr);
return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsInt();
#else
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support.");
@ -34,7 +34,7 @@ namespace storm {
double Z3SmtSolver::Z3ModelReference::getDoubleValue(std::string const& name) const {
#ifdef STORM_HAVE_Z3
z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createDoubleVariable(name));
z3::expr z3ExprValuation = model.eval(z3Expr, true);
z3::expr z3ExprValuation = model.eval(z3Expr);
return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsDouble();
#else
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support.");
@ -50,7 +50,7 @@ namespace storm {
config.set("model", true);
context = std::unique_ptr<z3::context>(new z3::context(config));
solver = std::unique_ptr<z3::solver>(new z3::solver(*context));
expressionAdapter = std::unique_ptr<storm::adapters::Z3ExpressionAdapter>(new storm::adapters::Z3ExpressionAdapter(*context, std::map<std::string, z3::expr>(), true));
expressionAdapter = std::unique_ptr<storm::adapters::Z3ExpressionAdapter>(new storm::adapters::Z3ExpressionAdapter(*context, true));
}
Z3SmtSolver::~Z3SmtSolver() {

Loading…
Cancel
Save