From 31c1357efab73266291018d5085de4d494faf5c2 Mon Sep 17 00:00:00 2001 From: David_Korzeniewski Date: Fri, 26 Sep 2014 13:14:36 +0200 Subject: [PATCH] alternative all sat callback Former-commit-id: 6fd7de7e511634c5940a1060108e7c9048870702 --- src/solver/SmtSolver.h | 23 +++++++++++++ src/solver/Z3SmtSolver.cpp | 67 ++++++++++++++++++++++++++++++++++++++ src/solver/Z3SmtSolver.h | 16 +++++++++ 3 files changed, 106 insertions(+) diff --git a/src/solver/SmtSolver.h b/src/solver/SmtSolver.h index 2417da40f..51235d28c 100644 --- a/src/solver/SmtSolver.h +++ b/src/solver/SmtSolver.h @@ -32,6 +32,12 @@ namespace storm { }; //! possible check results enum class CheckResult { SAT, UNSAT, UNKNOWN }; + + class ModelReference { + public: + virtual bool getBooleanValue(std::string const& name) const =0; + virtual int_fast64_t getIntegerValue(std::string const& name) const =0; + }; public: /*! * Constructs a new smt solver with the given options. @@ -158,6 +164,23 @@ namespace storm { throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support model generation."); } + /*! + * Performs all AllSat over the important atoms. Once a valuation of the important atoms such that the currently asserted formulas are satisfiable + * is found the callback is called with a reference to the model. The lifetime of that model is controlled by the solver implementation. It will most + * certainly be invalid after the callback returned. + * + * @param important A set of expressions over which to perform all sat. + * @param callback A function to call for each found valuation. + * + * @returns the number of valuations of the important atoms, such that the currently asserted formulas are satisfiable that where found + * + * @throws IllegalFunctionCallException if model generation is not configured for this solver + * @throws NotImplementedException if model generation is not implemented with this solver class + */ + virtual uint_fast64_t allSat(std::function callback, std::vector const& important) { + throw storm::exceptions::NotImplementedException("This subclass of SmtSolver does not support model generation."); + } //hack: switching the parameters is the only way to have overloading work with lambdas + /*! * Retrieves the unsat core of the last call to check() * diff --git a/src/solver/Z3SmtSolver.cpp b/src/solver/Z3SmtSolver.cpp index 45720c293..5f00947f5 100644 --- a/src/solver/Z3SmtSolver.cpp +++ b/src/solver/Z3SmtSolver.cpp @@ -3,6 +3,32 @@ namespace storm { namespace solver { +#ifdef STORM_HAVE_Z3 + Z3SmtSolver::Z3ModelReference::Z3ModelReference(z3::model &m, storm::adapters::Z3ExpressionAdapter &adapter) : m_model(m), m_adapter(adapter) { + + } +#endif + + bool Z3SmtSolver::Z3ModelReference::getBooleanValue(std::string const& name) const { +#ifdef STORM_HAVE_Z3 + z3::expr z3Expr = this->m_adapter.translateExpression(storm::expressions::Expression::createBooleanVariable(name)); + z3::expr z3ExprValuation = m_model.eval(z3Expr, true); + return this->m_adapter.translateExpression(z3ExprValuation).evaluateAsBool(); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + + int_fast64_t Z3SmtSolver::Z3ModelReference::getIntegerValue(std::string const& name) const { +#ifdef STORM_HAVE_Z3 + z3::expr z3Expr = this->m_adapter.translateExpression(storm::expressions::Expression::createIntegerVariable(name)); + z3::expr z3ExprValuation = m_model.eval(z3Expr, true); + return this->m_adapter.translateExpression(z3ExprValuation).evaluateAsInt(); +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + Z3SmtSolver::Z3SmtSolver(Options options) #ifdef STORM_HAVE_Z3 : m_context() @@ -240,6 +266,47 @@ namespace storm { LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); #endif } + + uint_fast64_t Z3SmtSolver::allSat(std::function callback, std::vector const& important) + { +#ifdef STORM_HAVE_Z3 + for (storm::expressions::Expression e : important) { + if (!e.isVariable()) { + throw storm::exceptions::InvalidArgumentException() << "The important expressions for AllSat must be atoms, i.e. variable expressions."; + } + } + + uint_fast64_t numModels = 0; + bool proceed = true; + + this->push(); + + while (proceed && this->check() == CheckResult::SAT) { + ++numModels; + z3::model m = this->m_solver.get_model(); + + z3::expr modelExpr = this->m_context.bool_val(true); + storm::expressions::SimpleValuation valuation; + + for (storm::expressions::Expression importantAtom : important) { + z3::expr z3ImportantAtom = this->m_adapter.translateExpression(importantAtom); + z3::expr z3ImportantAtomValuation = m.eval(z3ImportantAtom, true); + modelExpr = modelExpr && (z3ImportantAtom == z3ImportantAtomValuation); + } + + proceed = callback(Z3ModelReference(m, m_adapter)); + + this->m_solver.add(!modelExpr); + } + + this->pop(); + + return numModels; +#else + LOG_THROW(false, storm::exceptions::NotImplementedException, "StoRM is compiled without Z3 support."); +#endif + } + std::vector Z3SmtSolver::getUnsatAssumptions() { #ifdef STORM_HAVE_Z3 if (lastResult != SmtSolver::CheckResult::UNSAT) { diff --git a/src/solver/Z3SmtSolver.h b/src/solver/Z3SmtSolver.h index 0f2059443..1b5009957 100644 --- a/src/solver/Z3SmtSolver.h +++ b/src/solver/Z3SmtSolver.h @@ -13,6 +13,20 @@ namespace storm { namespace solver { class Z3SmtSolver : public SmtSolver { + public: + class Z3ModelReference : public SmtSolver::ModelReference { + public: +#ifdef STORM_HAVE_Z3 + Z3ModelReference(z3::model& m, storm::adapters::Z3ExpressionAdapter &adapter); +#endif + virtual bool getBooleanValue(std::string const& name) const override; + virtual int_fast64_t getIntegerValue(std::string const& name) const override; + private: +#ifdef STORM_HAVE_Z3 + z3::model &m_model; + storm::adapters::Z3ExpressionAdapter &m_adapter; +#endif + }; public: Z3SmtSolver(Options options = Options::ModelGeneration); virtual ~Z3SmtSolver(); @@ -39,6 +53,8 @@ namespace storm { virtual uint_fast64_t allSat(std::vector const& important, std::function callback) override; + virtual uint_fast64_t allSat(std::function callback, std::vector const& important) override; + virtual std::vector getUnsatAssumptions() override; protected: