diff --git a/src/mod_pomdp.cpp b/src/mod_pomdp.cpp index fb48fa7..d507824 100644 --- a/src/mod_pomdp.cpp +++ b/src/mod_pomdp.cpp @@ -1,5 +1,8 @@ #include "common.h" + +#include "pomdp/tracker.h" +#include "pomdp/qualitative_analysis.h" #include "pomdp/transformations.h" PYBIND11_MODULE(pomdp, m) { @@ -9,5 +12,8 @@ PYBIND11_MODULE(pomdp, m) { py::options options; options.disable_function_signatures(); #endif + define_tracker(m); + define_qualitative_policy_search(m, "Double"); + define_qualitative_policy_search_nt(m); define_transformations(m, "Double"); } diff --git a/src/pomdp/qualitative_analysis.cpp b/src/pomdp/qualitative_analysis.cpp new file mode 100644 index 0000000..d646851 --- /dev/null +++ b/src/pomdp/qualitative_analysis.cpp @@ -0,0 +1,46 @@ +#include "tracker.h" +#include "src/helpers.h" +#include +#include +#include +#include + +template using SparsePomdp = storm::models::sparse::Pomdp; + +template +std::shared_ptr> createWinningRegionSolver(SparsePomdp const& pomdp, storm::logic::Formula const& formula, storm::pomdp::MemlessSearchOptions const& options) { + + STORM_LOG_TRACE("Run qualitative preprocessing..."); + storm::analysis::QualitativeAnalysisOnGraphs qualitativeAnalysis(pomdp); + // After preprocessing, this might be done cheaper. + storm::storage::BitVector targetStates = qualitativeAnalysis.analyseProb1(formula.asProbabilityOperatorFormula()); + storm::storage::BitVector surelyNotAlmostSurelyReachTarget = qualitativeAnalysis.analyseProbSmaller1(formula.asProbabilityOperatorFormula()); + + storm::expressions::ExpressionManager expressionManager; + std::shared_ptr smtSolverFactory = std::make_shared(); + + return std::make_shared>(pomdp, targetStates, surelyNotAlmostSurelyReachTarget, smtSolverFactory, options); + +} + +template +void define_qualitative_policy_search(py::module& m, std::string const& vtSuffix) { + m.def(("create_iterative_qualitative_search_solver_" + vtSuffix).c_str(), &createWinningRegionSolver, "Create solver " ,py::arg("pomdp"), py::arg("formula"), py::arg("options")); + py::class_, std::shared_ptr>> mssq(m, ("IterativeQualitativeSearchSolver" + vtSuffix).c_str(), "Solver for POMDPs that solves qualitative queries"); + mssq.def("compute_winning_region", &storm::pomdp::MemlessStrategySearchQualitative::computeWinningRegion, py::arg("lookahead")); + mssq.def_property_readonly("last_winning_region", &storm::pomdp::MemlessStrategySearchQualitative::getLastWinningRegion, "get the last computed winning region"); + + py::class_> wrqi(m, ("BeliefSupportWinningRegionQueryInterface" + vtSuffix).c_str()); + wrqi.def(py::init const&, storm::pomdp::WinningRegion const&>(), py::arg("pomdp"), py::arg("BeliefSupportWinningRegion")); + wrqi.def("query_current_belief", &storm::pomdp::WinningRegionQueryInterface::isInWinningRegion, py::arg("current_belief")); + wrqi.def("query_action", &storm::pomdp::WinningRegionQueryInterface::staysInWinningRegion, py::arg("current_belief"), py::arg("action")); +} + +template void define_qualitative_policy_search(py::module& m, std::string const& vtSuffix); + +void define_qualitative_policy_search_nt(py::module& m) { + py::class_ mssqopts(m, "IterativeQualitativeSearchOptions", "Options for the IterativeQualitativeSearch"); + mssqopts.def(py::init<>()); + + py::class_ winningRegion(m, "BeliefSupportWinningRegion"); +} diff --git a/src/pomdp/qualitative_analysis.h b/src/pomdp/qualitative_analysis.h new file mode 100644 index 0000000..9b575a0 --- /dev/null +++ b/src/pomdp/qualitative_analysis.h @@ -0,0 +1,6 @@ +#pragma once +#include "common.h" + +template +void define_qualitative_policy_search(py::module& m, std::string const& vtSuffix); +void define_qualitative_policy_search_nt(py::module& m); \ No newline at end of file diff --git a/src/pomdp/tracker.cpp b/src/pomdp/tracker.cpp new file mode 100644 index 0000000..766d03c --- /dev/null +++ b/src/pomdp/tracker.cpp @@ -0,0 +1,15 @@ +#include "tracker.h" +#include "src/helpers.h" +#include + + +template using SparsePomdp = storm::models::sparse::Pomdp; +template using SparsePomdpTracker = storm::generator::BeliefSupportTracker; + + +void define_tracker(py::module& m) { + py::class_> tracker(m, "BeliefSupportTrackerDouble", "Tracker for BeliefSupports"); + tracker.def(py::init const&>(), py::arg("pomdp")); + tracker.def("get_current_belief_support", &SparsePomdpTracker::getCurrentBeliefSupport, "What is the support given the trace so far"); + tracker.def("track", &SparsePomdpTracker::track, py::arg("action"), py::arg("observation")); +} \ No newline at end of file diff --git a/src/pomdp/tracker.h b/src/pomdp/tracker.h new file mode 100644 index 0000000..9370d29 --- /dev/null +++ b/src/pomdp/tracker.h @@ -0,0 +1,4 @@ +#pragma once +#include "common.h" + +void define_tracker(py::module& m);