diff --git a/src/pars/pla.cpp b/src/pars/pla.cpp index 0ecd192..92ab461 100644 --- a/src/pars/pla.cpp +++ b/src/pars/pla.cpp @@ -2,21 +2,22 @@ #include "src/helpers.h" #include "storm/api/storm.h" + typedef storm::modelchecker::SparseDtmcParameterLiftingModelChecker, double> SparseDtmcRegionChecker; typedef storm::modelchecker::RegionModelChecker RegionModelChecker; typedef storm::storage::ParameterRegion Region; // Thin wrappers -std::shared_ptr createRegionChecker(std::shared_ptr> const& model, std::shared_ptr const& formula) { - return storm::api::initializeParameterLiftingRegionModelChecker(model, storm::api::createTask(formula, true)); +std::shared_ptr createRegionChecker(storm::Environment const& env, std::shared_ptr> const& model, std::shared_ptr const& formula) { + return storm::api::initializeParameterLiftingRegionModelChecker(env, model, storm::api::createTask(formula, true)); } -storm::modelchecker::RegionResult checkRegion(std::shared_ptr& checker, Region const& region, storm::modelchecker::RegionResultHypothesis const& hypothesis, storm::modelchecker::RegionResult const& initialResult, bool sampleVertices) { - return checker->analyzeRegion(region, hypothesis, initialResult, sampleVertices); +storm::modelchecker::RegionResult checkRegion(std::shared_ptr& checker, storm::Environment const& env, Region const& region, storm::modelchecker::RegionResultHypothesis const& hypothesis, storm::modelchecker::RegionResult const& initialResult, bool sampleVertices) { + return checker->analyzeRegion(env, region, hypothesis, initialResult, sampleVertices); } -storm::RationalFunction getBound(std::shared_ptr& checker, Region const& region, bool maximise) { - return checker->getBoundAtInitState(region, maximise ? storm::solver::OptimizationDirection::Maximize : storm::solver::OptimizationDirection::Minimize); +storm::RationalFunction getBound(std::shared_ptr& checker, storm::Environment const& env, Region const& region, bool maximise) { + return checker->getBoundAtInitState(env, region, maximise ? storm::solver::OptimizationDirection::Maximize : storm::solver::OptimizationDirection::Minimize); } @@ -74,11 +75,11 @@ void define_pla(py::module& m) { auto tmp = storm::api::initializeParameterLiftingRegionModelChecker(model, task); new (&instance) std::unique_ptr(tmp); }, py::arg("model"), py::arg("task")*/ - .def("check_region", &checkRegion, "Check region", py::arg("region"), py::arg("hypothesis") = storm::modelchecker::RegionResultHypothesis::Unknown, py::arg("initialResult") = storm::modelchecker::RegionResult::Unknown, py::arg("sampleVertices") = false) - .def("get_bound", &getBound, "Get bound", py::arg("region"), py::arg("maximise")= true); + .def("check_region", &checkRegion, "Check region", py::arg("environment"), py::arg("region"), py::arg("hypothesis") = storm::modelchecker::RegionResultHypothesis::Unknown, py::arg("initialResult") = storm::modelchecker::RegionResult::Unknown, py::arg("sampleVertices") = false) + .def("get_bound", &getBound, "Get bound", py::arg("environment"), py::arg("region"), py::arg("maximise")= true); ; - m.def("create_region_checker", &createRegionChecker, "Create region checker", py::arg("model"), py::arg("formula")); + m.def("create_region_checker", &createRegionChecker, "Create region checker", py::arg("environment"), py::arg("model"), py::arg("formula")); //m.def("is_parameter_lifting_sound", &storm::utility::parameterlifting::validateParameterLiftingSound, "Check if parameter lifting is sound", py::arg("model"), py::arg("formula")); m.def("gather_derivatives", &gatherDerivatives, "Gather all derivatives of transition probabilities", py::arg("model"), py::arg("var")); } diff --git a/tests/pars/test_pla.py b/tests/pars/test_pla.py index 314ce14..6336527 100644 --- a/tests/pars/test_pla.py +++ b/tests/pars/test_pla.py @@ -16,16 +16,17 @@ class TestPLA: assert model.nr_transitions == 803 assert model.model_type == stormpy.ModelType.DTMC assert model.has_parameters - checker = stormpy.pars.create_region_checker(model, formulas[0].raw_formula) + env = stormpy.Environment() + checker = stormpy.pars.create_region_checker(env, model, formulas[0].raw_formula) parameters = model.collect_probability_parameters() assert len(parameters) == 2 region = stormpy.pars.ParameterRegion("0.7<=pL<=0.9,0.75<=pK<=0.95", parameters) - result = checker.check_region(region) + result = checker.check_region(env, region) assert result == stormpy.pars.RegionResult.ALLSAT region = stormpy.pars.ParameterRegion("0.4<=pL<=0.65,0.75<=pK<=0.95", parameters) - result = checker.check_region(region, stormpy.pars.RegionResultHypothesis.UNKNOWN, + result = checker.check_region(env, region, stormpy.pars.RegionResultHypothesis.UNKNOWN, stormpy.pars.RegionResult.UNKNOWN, True) assert result == stormpy.pars.RegionResult.EXISTSBOTH region = stormpy.pars.ParameterRegion("0.1<=pL<=0.73,0.2<=pK<=0.715", parameters) - result = checker.check_region(region) + result = checker.check_region(env, region) assert result == stormpy.pars.RegionResult.ALLVIOLATED