diff --git a/src/solver/MathSatSmtSolver.cpp b/src/solver/MathSatSmtSolver.cpp index ac8c6622a..7d8ee19d1 100644 --- a/src/solver/MathSatSmtSolver.cpp +++ b/src/solver/MathSatSmtSolver.cpp @@ -4,31 +4,6 @@ namespace storm { namespace solver { -#ifdef STORM_HAVE_MSAT - MathSatSmtSolver::MathSatModelReference::Z3ModelReference(z3::model &m, storm::adapters::Z3ExpressionAdapter &adapter) : m_model(m), m_adapter(adapter) { - - } -#endif - - bool MathSatSmtSolver::Z3ModelReference::getBooleanValue(std::string const& name) const { -#ifdef STORM_HAVE_MSAT - z3::expr z3Expr = this->m_adapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); - z3::expr z3ExprValuation = m_model.eval(z3Expr, true); - return this->m_adapter.translateExpression(z3ExprValuation).evaluateAsBool(); -#else - LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); -#endif - } - - int_fast64_t MathSatSmtSolver::Z3ModelReference::getIntegerValue(std::string const& name) const { -#ifdef STORM_HAVE_MSAT - z3::expr z3Expr = this->m_adapter.translateExpression(storm::expressions::Expression::createIntegerVariable(name)); - z3::expr z3ExprValuation = m_model.eval(z3Expr, true); - return this->m_adapter.translateExpression(z3ExprValuation).evaluateAsInt(); -#else - LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); -#endif - } MathSatSmtSolver::MathSatSmtSolver(Options options) #ifdef STORM_HAVE_MSAT @@ -38,14 +13,13 @@ namespace storm { { #ifdef STORM_HAVE_MSAT m_cfg = msat_create_config(); - + if (static_cast(options)& static_cast(Options::InterpolantComputation)) { - msat_res = msat_set_option(m_cfg, "interpolation", "true"); - if (msat_res != 0) { - LOG4CPLUS_WARN(logger, "MathSAT returned an error!"); - } + msat_set_option(m_cfg, "interpolation", "true"); } m_env = msat_create_env(m_cfg); + + m_adapter = new storm::adapters::MathSatExpressionAdapter(m_env, variableToDeclMap); #endif } MathSatSmtSolver::~MathSatSmtSolver() { @@ -94,7 +68,7 @@ namespace storm { void MathSatSmtSolver::assertExpression(storm::expressions::Expression const& e) { #ifdef STORM_HAVE_MSAT - msat_assert_formula(m_env, m_adapter.translateExpression(e, true)); + msat_assert_formula(m_env, m_adapter->translateExpression(e, true)); #else LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); #endif @@ -129,7 +103,7 @@ namespace storm { mathSatAssumptions.reserve(assumptions.size()); for (storm::expressions::Expression assumption : assumptions) { - mathSatAssumptions.push_back(this->m_adapter.translateExpression(assumption)); + mathSatAssumptions.push_back(this->m_adapter->translateExpression(assumption)); } switch (msat_solve_with_assumptions(m_env, mathSatAssumptions.data(), mathSatAssumptions.size())) { @@ -157,7 +131,7 @@ namespace storm { mathSatAssumptions.reserve(assumptions.size()); for (storm::expressions::Expression assumption : assumptions) { - mathSatAssumptions.push_back(this->m_adapter.translateExpression(assumption)); + mathSatAssumptions.push_back(this->m_adapter->translateExpression(assumption)); } switch (msat_solve_with_assumptions(m_env, mathSatAssumptions.data(), mathSatAssumptions.size())) { @@ -199,7 +173,7 @@ namespace storm { msat_term t, v; msat_model_iterator_next(model, &t, &v); - storm::expressions::Expression var_i_interp = this->m_adapter.translateTerm(v); + storm::expressions::Expression var_i_interp = this->m_adapter->translateTerm(v); char* name = msat_decl_get_name(msat_term_get_decl(t)); switch (var_i_interp.getReturnType()) { @@ -244,18 +218,27 @@ namespace storm { #ifdef STORM_HAVE_MSAT struct AllsatValuationsCallbackUserData { - msat_env env; + msat_env &env; storm::adapters::MathSatExpressionAdapter &adapter; - storm::expressions::SimpleValuation& valuation; - uint_fast64_t n; + std::function &callback; }; int allsatValuationsCallback(msat_term *model, int size, void *user_data) { AllsatValuationsCallbackUserData* user = reinterpret_cast(user_data); - ++n; + + storm::expressions::SimpleValuation valuation; for (int i = 0; i < size; ++i) { - /// + bool currentTermValue = true; + msat_term currentTerm = model[i]; + if (msat_term_is_not(user->env, currentTerm)) { + currentTerm = msat_term_get_arg(currentTerm, 0); + currentTermValue = false; + } + char* name = msat_decl_get_name(msat_term_get_decl(currentTerm)); + std::string name_str(name); + valuation.addBooleanIdentifier(name_str, currentTermValue); + msat_free(name); } } #endif @@ -264,13 +247,21 @@ namespace storm { uint_fast64_t MathSatSmtSolver::allSat(std::vector const& important, std::function callback) { #ifdef STORM_HAVE_MSAT + std::vector msatImportant; + msatImportant.reserve(important.size()); + for (storm::expressions::Expression e : important) { if (!e.isVariable()) { throw storm::exceptions::InvalidArgumentException() << "The important expressions for AllSat must be atoms, i.e. variable expressions."; } + msatImportant.push_back(m_adapter->translateExpression(e)); } - + AllsatValuationsCallbackUserData allSatUserData; + allSatUserData.adapter = m_adapter; + allSatUserData.env = m_env; + allSatUserData.callback = callback; + int numModels = msat_all_sat(m_env, msatImportant.data(), msatImportant.size(), &allsatValuationsCallback, &allSatUserData); return numModels; #else @@ -281,13 +272,7 @@ namespace storm { uint_fast64_t MathSatSmtSolver::allSat(std::function callback, std::vector const& important) { #ifdef STORM_HAVE_MSAT - for (storm::expressions::Expression e : important) { - if (!e.isVariable()) { - throw storm::exceptions::InvalidArgumentException() << "The important expressions for AllSat must be atoms, i.e. variable expressions."; - } - } - - return numModels; + LOG_THROW(false, storm::exceptions::NotImplementedException, "Not Implemented."); #else LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); #endif @@ -309,12 +294,55 @@ namespace storm { unsatAssumptions.reserve(numUnsatAssumpations); for (unsigned int i = 0; i < numUnsatAssumpations; ++i) { - unsatAssumptions.push_back(this->m_adapter.translateTerm(msatUnsatAssumptions[i])); + unsatAssumptions.push_back(this->m_adapter->translateTerm(msatUnsatAssumptions[i])); } return unsatAssumptions; #else LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); +#endif + } + + void MathSatSmtSolver::setInterpolationGroup(uint_fast64_t group) { +#ifdef STORM_HAVE_MSAT + auto groupIter = this->interpolationGroups.find(group); + if( groupIter == this->interpolationGroups.end() ) { + int newGroup = msat_create_itp_group(m_env); + auto insertResult = this->interpolationGroups.insert(std::make_pair(group, newGroup)); + if (!insertResult.second) { + throw storm::exceptions::InvalidStateException() << "Internal error in MathSAT wrapper: Unable to insert newly created interpolation group."; + } + groupIter = insertResult.first; + } + msat_set_itp_group(m_env, groupIter->second); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); +#endif + } + + storm::expressions::Expression MathSatSmtSolver::getInterpolant(std::vector groupsA) { +#ifdef STORM_HAVE_MSAT + if (lastResult != SmtSolver::CheckResult::UNSAT) { + throw storm::exceptions::InvalidStateException() << "getInterpolant was called but last state is not unsat."; + } + if (lastCheckAssumptions) { + throw storm::exceptions::InvalidStateException() << "getInterpolant was called but last check had assumptions."; + } + + std::vector msatInterpolationGroupsA; + msatInterpolationGroupsA.reserve(groupsA.size()); + for (auto groupOfA : groupsA) { + auto groupIter = this->interpolationGroups.find(groupOfA); + if (groupIter == this->interpolationGroups.end()) { + throw storm::exceptions::InvalidArgumentException() << "Requested interpolant for non existing interpolation group " << groupOfA; + } + msatInterpolationGroupsA.push_back(groupIter->second); + } + msat_term interpolant = msat_get_interpolant(m_env, msatInterpolationGroupsA.data(), msatInterpolationGroupsA.size()); + + return this->m_adapter->translateTerm(interpolant); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without MathSat support."); #endif } } diff --git a/src/solver/MathSatSmtSolver.h b/src/solver/MathSatSmtSolver.h index f6546b768..f2f73c460 100644 --- a/src/solver/MathSatSmtSolver.h +++ b/src/solver/MathSatSmtSolver.h @@ -4,6 +4,7 @@ #include "storm-config.h" #include "src/solver/SmtSolver.h" #include "src/adapters/MathSatExpressionAdapter.h" +#include #ifndef STORM_HAVE_MSAT #define STORM_HAVE_MSAT @@ -73,10 +74,13 @@ namespace storm { #ifdef STORM_HAVE_MSAT msat_config m_cfg; msat_env m_env; - storm::adapters::MathSatExpressionAdapter m_adapter; + storm::adapters::MathSatExpressionAdapter *m_adapter; bool lastCheckAssumptions; CheckResult lastResult; + typedef boost::container::flat_map InterpolationGroupMap; + InterpolationGroupMap interpolationGroups; + std::map variableToDeclMap; #endif }; }