From 4418422ea83afb312ca0e2f66b366b64e45109f8 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Tue, 3 Dec 2019 17:30:49 +0100 Subject: [PATCH] merge -- but code is not working atm --- .../MemlessStrategySearchQualitative.cpp | 2 +- .../MemlessStrategySearchQualitative.h | 79 +++++++++++++++++++ .../ApproximatePOMDPModelchecker.cpp | 6 +- src/storm/adapters/Z3ExpressionAdapter.cpp | 3 +- src/storm/solver/SmtSolver.h | 4 + src/storm/solver/SmtlibSmtSolver.cpp | 5 ++ src/storm/solver/SmtlibSmtSolver.h | 2 +- src/storm/solver/Z3SmtSolver.cpp | 13 ++- src/storm/solver/Z3SmtSolver.h | 1 + 9 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp index 0c46fd0d3..2ed1230e0 100644 --- a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.cpp @@ -61,7 +61,7 @@ namespace storm { smtSolver->add(!pathVars[state][0]); } - if (surelyReachSinkStates.at(state)) { + if (surelyReachSinkStates.get(state)) { smtSolver->add(!reachVars[state]); } else { diff --git a/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h new file mode 100644 index 000000000..ca1a84aa6 --- /dev/null +++ b/src/storm-pomdp/analysis/MemlessStrategySearchQualitative.h @@ -0,0 +1,79 @@ +#include +#include "storm/storage/expressions/Expressions.h" +#include "storm/solver/SmtSolver.h" +#include "storm/models/sparse/Pomdp.h" +#include "storm/utility/solver.h" + +namespace storm { +namespace pomdp { + + template + class MemlessStrategySearchQualitative { + // Implements an extension to the Chatterjee, Chmelik, Davies (AAAI-16) paper. + + + public: + MemlessStrategySearchQualitative(storm::models::sparse::Pomdp const& pomdp, + std::set const& targetObservationSet, + std::shared_ptr& smtSolverFactory) : + pomdp(pomdp), + targetObservations(targetObservationSet) { + this->expressionManager = std::make_shared(); + smtSolver = smtSolverFactory->create(*expressionManager); + + } + + void setSurelyReachSinkStates(storm::storage::BitVector const& surelyReachSink) { + surelyReachSinkStates = surelyReachSink; + } + + void analyze(uint64_t k) { + if (k < maxK) { + initialize(k); + } + std::cout << smtSolver->getSmtLibString() << std::endl; + for (uint64_t state : pomdp.getInitialStates()) { + smtSolver->add(reachVars[state]); + } + auto result = smtSolver->check(); + switch(result) { + case storm::solver::SmtSolver::CheckResult::Sat: + std::cout << std::endl << "Satisfying assignment: " << std::endl << smtSolver->getModelAsValuation().toString(true) << std::endl; + + case storm::solver::SmtSolver::CheckResult::Unsat: + // std::cout << std::endl << "Unsatisfiability core: {" << std::endl; + // for (auto const& expr : solver->getUnsatCore()) { + // std::cout << "\t " << expr << std::endl; + // } + // std::cout << "}" << std::endl; + + default: + std::cout<< "oops." << std::endl; + // STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "SMT solver yielded an unexpected result"); + } + //std::cout << "get model:" << std::endl; + //std::cout << smtSolver->getModel().toString() << std::endl; + } + + + private: + void initialize(uint64_t k); + + std::unique_ptr smtSolver; + storm::models::sparse::Pomdp const& pomdp; + std::shared_ptr expressionManager; + uint64_t maxK = -1; + + std::set targetObservations; + storm::storage::BitVector surelyReachSinkStates; + + std::vector> statesPerObservation; + std::vector> actionSelectionVars; // A_{z,a} + std::vector reachVars; + std::vector> pathVars; + + + + }; +} +} diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp index 37a29f9b1..3cd362e58 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp @@ -10,6 +10,7 @@ #include "storm/modelchecker/results/CheckResult.h" #include "storm/modelchecker/results/ExplicitQualitativeCheckResult.h" #include "storm/modelchecker/results/ExplicitQuantitativeCheckResult.h" +#include "storm/models/sparse/StandardRewardModel.h" #include "storm/api/properties.h" #include "storm/api/export.h" #include "storm-parsers/api/storm-parsers.h" @@ -548,7 +549,8 @@ namespace storm { template std::unique_ptr> ApproximatePOMDPModelchecker::computeReachabilityReward(storm::models::sparse::Pomdp const &pomdp, - std::set const &targetObservations, bool min, uint64_t gridResolution) { + std::set const &targetObservations, bool min, + uint64_t gridResolution) { return computeReachability(pomdp, targetObservations, min, gridResolution, true); } @@ -1088,8 +1090,6 @@ namespace storm { class ApproximatePOMDPModelchecker; #ifdef STORM_HAVE_CARL - - //template class ApproximatePOMDPModelchecker; template class ApproximatePOMDPModelchecker; diff --git a/src/storm/adapters/Z3ExpressionAdapter.cpp b/src/storm/adapters/Z3ExpressionAdapter.cpp index bb9afb6f3..437f003b9 100644 --- a/src/storm/adapters/Z3ExpressionAdapter.cpp +++ b/src/storm/adapters/Z3ExpressionAdapter.cpp @@ -37,8 +37,7 @@ namespace storm { result = result && assertion; } additionalAssertions.clear(); - - return result; + return result.simplify(); } z3::expr Z3ExpressionAdapter::translateExpression(storm::expressions::Variable const& variable) { diff --git a/src/storm/solver/SmtSolver.h b/src/storm/solver/SmtSolver.h index 7ee3896bd..4077c244a 100644 --- a/src/storm/solver/SmtSolver.h +++ b/src/storm/solver/SmtSolver.h @@ -23,6 +23,8 @@ namespace storm { public: //! possible check results enum class CheckResult { Sat, Unsat, Unknown }; + + /*! * The base class for all model references. They are used to provide a lightweight method of accessing the @@ -48,6 +50,8 @@ namespace storm { * @return The expression manager associated with this model reference. */ storm::expressions::ExpressionManager const& getManager() const; + + virtual std::string toString() const = 0; private: // The expression manager responsible for the variables whose value can be requested via this model diff --git a/src/storm/solver/SmtlibSmtSolver.cpp b/src/storm/solver/SmtlibSmtSolver.cpp index aad46d81e..1ef7fda76 100644 --- a/src/storm/solver/SmtlibSmtSolver.cpp +++ b/src/storm/solver/SmtlibSmtSolver.cpp @@ -40,6 +40,11 @@ namespace storm { STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "functionality not (yet) implemented"); } + std::string SmtlibSmtSolver::SmtlibModelReference::toString() const { + STORM_LOG_THROW(false, storm::exceptions::NotImplementedException, "functionality not (yet) implemented"); + } + + SmtlibSmtSolver::SmtlibSmtSolver(storm::expressions::ExpressionManager& manager, bool useCarlExpressions) : SmtSolver(manager), isCommandFileOpen(false), expressionAdapter(nullptr), useCarlExpressions(useCarlExpressions) { #ifndef STORM_HAVE_CARL STORM_LOG_THROW(!useCarlExpressions, storm::exceptions::IllegalArgumentException, "Tried to use carl expressions but storm is not linked with CARL"); diff --git a/src/storm/solver/SmtlibSmtSolver.h b/src/storm/solver/SmtlibSmtSolver.h index 6ae64eb92..cd31e58d5 100644 --- a/src/storm/solver/SmtlibSmtSolver.h +++ b/src/storm/solver/SmtlibSmtSolver.h @@ -26,7 +26,7 @@ namespace storm { virtual bool getBooleanValue(storm::expressions::Variable const& variable) const override; virtual int_fast64_t getIntegerValue(storm::expressions::Variable const& variable) const override; virtual double getRationalValue(storm::expressions::Variable const& variable) const override; - + virtual std::string toString() const override; }; public: diff --git a/src/storm/solver/Z3SmtSolver.cpp b/src/storm/solver/Z3SmtSolver.cpp index 13579a172..7d3745fa7 100644 --- a/src/storm/solver/Z3SmtSolver.cpp +++ b/src/storm/solver/Z3SmtSolver.cpp @@ -44,7 +44,18 @@ namespace storm { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Storm is compiled without Z3 support."); #endif } - + + std::string Z3SmtSolver::Z3ModelReference::toString() const { +#ifdef STORM_HAVE_Z3 + std::stringstream sstr; + sstr << model; + return sstr.str(); +#else + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Storm is compiled without Z3 support."); +#endif + } + + Z3SmtSolver::Z3SmtSolver(storm::expressions::ExpressionManager& manager) : SmtSolver(manager) #ifdef STORM_HAVE_Z3 , context(nullptr), solver(nullptr), expressionAdapter(nullptr), lastCheckAssumptions(false), lastResult(CheckResult::Unknown) diff --git a/src/storm/solver/Z3SmtSolver.h b/src/storm/solver/Z3SmtSolver.h index 6616e0292..fe92299a7 100644 --- a/src/storm/solver/Z3SmtSolver.h +++ b/src/storm/solver/Z3SmtSolver.h @@ -22,6 +22,7 @@ namespace storm { virtual bool getBooleanValue(storm::expressions::Variable const& variable) const override; virtual int_fast64_t getIntegerValue(storm::expressions::Variable const& variable) const override; virtual double getRationalValue(storm::expressions::Variable const& variable) const override; + virtual std::string toString() const override; private: #ifdef STORM_HAVE_Z3