From f54b5671ea9769c5c8eb9ae9c31961db82079e2f Mon Sep 17 00:00:00 2001 From: dehnert Date: Mon, 22 Dec 2014 15:45:08 +0100 Subject: [PATCH] Done refactoring MathSAT expression adapter. Former-commit-id: 6edb98b86c6d80ff7e72575100aa7da4fab879ba --- src/adapters/MathsatExpressionAdapter.h | 89 +++++++++++-------------- src/solver/MathsatSmtSolver.cpp | 20 +++--- src/solver/Z3SmtSolver.cpp | 8 +-- 3 files changed, 53 insertions(+), 64 deletions(-) diff --git a/src/adapters/MathsatExpressionAdapter.h b/src/adapters/MathsatExpressionAdapter.h index cb29232e1..f07b2cac5 100644 --- a/src/adapters/MathsatExpressionAdapter.h +++ b/src/adapters/MathsatExpressionAdapter.h @@ -34,7 +34,7 @@ namespace storm { * expressions and are not yet known to the adapter. * @param variableToDeclarationMap A mapping from variable names to their corresponding MathSAT declarations (if already existing). */ - MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map const& variableToDeclarationMap = std::map()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap) { + MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map const& variableToDeclarationMap = std::map()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) { // Intentionally left empty. } @@ -75,8 +75,7 @@ namespace storm { stack.push(msat_make_or(env, msat_make_not(env, leftResult), rightResult)); break; default: - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); } } @@ -110,8 +109,8 @@ namespace storm { case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult)); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown numerical binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; + default: + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); } } @@ -151,8 +150,8 @@ namespace storm { case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: stack.push(msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult)))); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << static_cast(expression->getRelationType()) << "' in expression " << expression << "."; + default: + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getRelationType()) << "' in expression " << expression << "."); } } @@ -193,8 +192,8 @@ namespace storm { case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: stack.push(msat_make_not(env, childResult)); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown boolean binary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."; + default: + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); } } @@ -208,52 +207,39 @@ namespace storm { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: stack.push(msat_make_times(env, msat_make_number(env, "-1"), childResult)); break; - default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unknown numerical unary operator: '" << static_cast(expression->getOperatorType()) << "'."; + default: + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); } } virtual void visit(expressions::VariableExpression const* expression) override { - if (createMathSatVariables) { - std::map variables; - - try { - variables = expression.getVariablesAndTypes(); - } - catch (storm::exceptions::InvalidTypeException* e) { - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with ambigious type while trying to autocreate solver variables: " << e); - } - - for (auto variableAndType : variables) { - if (this->variableToDeclarationMap.find(variableAndType.first) == this->variableToDeclarationMap.end()) { - switch (variableAndType.second) - { - case storm::expressions::ExpressionReturnType::Bool: - this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_bool_type(env)))); - break; - case storm::expressions::ExpressionReturnType::Int: - this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_integer_type(env)))); - break; - case storm::expressions::ExpressionReturnType::Double: - this->variableToDeclarationMap.insert(std::make_pair(variableAndType.first, msat_declare_function(env, variableAndType.first.c_str(), msat_get_rational_type(env)))); - break; - default: - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable with unknown type while trying to autocreate solver variables: " << variableAndType.first); - break; - } - } - } - } + std::map::iterator stringVariablePair = variableToDeclarationMap.find(expression->getVariableName()); + msat_decl result; - STORM_LOG_THROW(variableToDeclarationMap.count(expression->getVariableName()) != 0, storm::exceptions::InvalidArgumentException, "Variable '" << expression->getVariableName() << "' is unknown."); - //LOG4CPLUS_TRACE(logger, "Variable "<getVariableName()); - //char* repr = msat_decl_repr(variableToDeclMap.at(expression->getVariableName())); - //LOG4CPLUS_TRACE(logger, "Decl: "<getVariableName()))) { - STORM_LOG_WARN("Encountered an invalid MathSAT declaration."); - } - stack.push(msat_make_constant(env, variableToDeclarationMap.at(expression->getVariableName()))); + if (stringVariablePair == variableToDeclarationMap.end() && createVariables) { + std::pair::iterator, bool> iteratorAndFlag; + switch (expression->getReturnType()) { + case storm::expressions::ExpressionReturnType::Bool: + iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_bool_type(env)))); + result = iteratorAndFlag.first->second; + break; + case storm::expressions::ExpressionReturnType::Int: + iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_integer_type(env)))); + result = iteratorAndFlag.first->second; + break; + case storm::expressions::ExpressionReturnType::Double: + iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_rational_type(env)))); + result = iteratorAndFlag.first->second; + break; + default: + STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression->getVariableName() << "' with unknown type while trying to create solver variables."); + } + } else { + STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression->getVariableName() << "'."); + result = stringVariablePair->second; + } + + stack.push(msat_make_constant(env, result)); } storm::expressions::Expression translateExpression(msat_term const& term) { @@ -319,6 +305,9 @@ namespace storm { // A mapping of variable names to their declaration in the MathSAT environment. std::map variableToDeclarationMap; + + // A flag indicating whether variables are supposed to be created if they are not already known to the adapter. + bool createVariables; }; #endif } // namespace adapters diff --git a/src/solver/MathsatSmtSolver.cpp b/src/solver/MathsatSmtSolver.cpp index 809fd26d9..3c0778b3d 100644 --- a/src/solver/MathsatSmtSolver.cpp +++ b/src/solver/MathsatSmtSolver.cpp @@ -38,28 +38,28 @@ namespace storm { } #endif bool MathsatSmtSolver::MathsatModelReference::getBooleanValue(std::string const& name) const { - msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name), false); + msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); msat_term msatValue = msat_get_model_value(env, msatVariable); STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); - storm::expressions::Expression value = expressionAdapter.translateTerm(msatValue); + storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue); STORM_LOG_THROW(value.hasBooleanReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve boolean value of non-boolean variable '" << name << "'."); return value.evaluateAsBool(); } int_fast64_t MathsatSmtSolver::MathsatModelReference::getIntegerValue(std::string const& name) const { - msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name), false); + msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); msat_term msatValue = msat_get_model_value(env, msatVariable); STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); - storm::expressions::Expression value = expressionAdapter.translateTerm(msatValue); + storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue); STORM_LOG_THROW(value.hasIntegralReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve integer value of non-integer variable '" << name << "'."); return value.evaluateAsInt(); } double MathsatSmtSolver::MathsatModelReference::getDoubleValue(std::string const& name) const { - msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name), false); + msat_term msatVariable = expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); msat_term msatValue = msat_get_model_value(env, msatVariable); STORM_LOG_THROW(!MSAT_ERROR_TERM(msatValue), storm::exceptions::UnexpectedException, "Unable to retrieve value of variable in model. This could be caused by calls to the solver between checking for satisfiability and model retrieval."); - storm::expressions::Expression value = expressionAdapter.translateTerm(msatValue); + storm::expressions::Expression value = expressionAdapter.translateExpression(msatValue); STORM_LOG_THROW(value.hasIntegralReturnType(), storm::exceptions::InvalidArgumentException, "Unable to retrieve double value of non-double variable '" << name << "'."); return value.evaluateAsDouble(); } @@ -128,7 +128,7 @@ namespace storm { void MathsatSmtSolver::add(storm::expressions::Expression const& e) { #ifdef STORM_HAVE_MSAT - msat_assert_formula(env, expressionAdapter->translateExpression(e, true)); + msat_assert_formula(env, expressionAdapter->translateExpression(e)); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without MathSAT support."); #endif @@ -241,7 +241,7 @@ namespace storm { msat_term t, v; msat_model_iterator_next(modelIterator, &t, &v); - storm::expressions::Expression variableInterpretation = this->expressionAdapter->translateTerm(v); + storm::expressions::Expression variableInterpretation = this->expressionAdapter->translateExpression(v); char* name = msat_decl_get_name(msat_term_get_decl(t)); switch (variableInterpretation.getReturnType()) { @@ -411,7 +411,7 @@ namespace storm { unsatAssumptions.reserve(numUnsatAssumpations); for (unsigned int i = 0; i < numUnsatAssumpations; ++i) { - unsatAssumptions.push_back(this->expressionAdapter->translateTerm(msatUnsatAssumptions[i])); + unsatAssumptions.push_back(this->expressionAdapter->translateExpression(msatUnsatAssumptions[i])); } return unsatAssumptions; @@ -450,7 +450,7 @@ namespace storm { STORM_LOG_THROW(!MSAT_ERROR_TERM(interpolant), storm::exceptions::UnexpectedException, "Unable to retrieve an interpolant."); - return this->expressionAdapter->translateTerm(interpolant); + return this->expressionAdapter->translateExpression(interpolant); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without MathSAT support."); #endif diff --git a/src/solver/Z3SmtSolver.cpp b/src/solver/Z3SmtSolver.cpp index 1de65eb74..09daf9ae0 100644 --- a/src/solver/Z3SmtSolver.cpp +++ b/src/solver/Z3SmtSolver.cpp @@ -14,7 +14,7 @@ namespace storm { bool Z3SmtSolver::Z3ModelReference::getBooleanValue(std::string const& name) const { #ifdef STORM_HAVE_Z3 z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); - z3::expr z3ExprValuation = model.eval(z3Expr, true); + z3::expr z3ExprValuation = model.eval(z3Expr); return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsBool(); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); @@ -24,7 +24,7 @@ namespace storm { int_fast64_t Z3SmtSolver::Z3ModelReference::getIntegerValue(std::string const& name) const { #ifdef STORM_HAVE_Z3 z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createIntegerVariable(name)); - z3::expr z3ExprValuation = model.eval(z3Expr, true); + z3::expr z3ExprValuation = model.eval(z3Expr); return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsInt(); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); @@ -34,7 +34,7 @@ namespace storm { double Z3SmtSolver::Z3ModelReference::getDoubleValue(std::string const& name) const { #ifdef STORM_HAVE_Z3 z3::expr z3Expr = this->expressionAdapter.translateExpression(storm::expressions::Expression::createDoubleVariable(name)); - z3::expr z3ExprValuation = model.eval(z3Expr, true); + z3::expr z3ExprValuation = model.eval(z3Expr); return this->expressionAdapter.translateExpression(z3ExprValuation).evaluateAsDouble(); #else STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "StoRM is compiled without Z3 support."); @@ -50,7 +50,7 @@ namespace storm { config.set("model", true); context = std::unique_ptr(new z3::context(config)); solver = std::unique_ptr(new z3::solver(*context)); - expressionAdapter = std::unique_ptr(new storm::adapters::Z3ExpressionAdapter(*context, std::map(), true)); + expressionAdapter = std::unique_ptr(new storm::adapters::Z3ExpressionAdapter(*context, true)); } Z3SmtSolver::~Z3SmtSolver() {