diff --git a/src/pars/pla.cpp b/src/pars/pla.cpp index 1e786db..e2fe24e 100644 --- a/src/pars/pla.cpp +++ b/src/pars/pla.cpp @@ -11,8 +11,8 @@ std::shared_ptr createRegionChecker(std::shared_ptr(model, storm::api::createTask(formula, true)); } -storm::modelchecker::RegionResult checkRegion(std::shared_ptr& checker, Region& region, storm::modelchecker::RegionResult initialResult, bool sampleVertices) { - return checker->analyzeRegion(region, initialResult, sampleVertices); +storm::modelchecker::RegionResult checkRegion(std::shared_ptr& checker, Region const& region, storm::modelchecker::RegionResultHypothesis const& hypothesis, storm::modelchecker::RegionResult const& initialResult, bool sampleVertices) { + return checker->analyzeRegion(region, hypothesis, initialResult, sampleVertices); } std::set gatherDerivatives(storm::models::sparse::Dtmc const& model, carl::Variable const& var) { @@ -44,6 +44,14 @@ void define_pla(py::module& m) { .def("__str__", &streamToString) ; + // RegionResultHypothesis + py::enum_(m, "RegionResultHypothesis", "Hypothesis for the result of a parameter region") + .value("UNKNOWN", storm::modelchecker::RegionResultHypothesis::Unknown) + .value("ALLSAT", storm::modelchecker::RegionResultHypothesis::AllSat) + .value("ALLVIOLATED", storm::modelchecker::RegionResultHypothesis::AllViolated) + .def("__str__", &streamToString) + ; + // Region py::class_>(m, "ParameterRegion", "Parameter region") .def("__init__", [](Region &instance, std::string const& regionString, std::set const& variables) -> void { @@ -61,7 +69,7 @@ void define_pla(py::module& m) { auto tmp = storm::api::initializeParameterLiftingRegionModelChecker(model, task); new (&instance) std::unique_ptr(tmp); }, py::arg("model"), py::arg("task")*/ - .def("check_region", &checkRegion, "Check region", py::arg("region"), py::arg("initialResult") = storm::modelchecker::RegionResult::Unknown, py::arg("sampleVertices") = false) + .def("check_region", &checkRegion, "Check region", py::arg("region"), py::arg("hypothesis") = storm::modelchecker::RegionResultHypothesis::Unknown, py::arg("initialResult") = storm::modelchecker::RegionResult::Unknown, py::arg("sampleVertices") = false) ; m.def("create_region_checker", &createRegionChecker, "Create region checker", py::arg("model"), py::arg("formula")); diff --git a/tests/pars/test_pla.py b/tests/pars/test_pla.py index 62b8030..1fbcf1f 100644 --- a/tests/pars/test_pla.py +++ b/tests/pars/test_pla.py @@ -21,7 +21,7 @@ class TestModelChecking: result = checker.check_region(region) assert result == stormpy.pars.RegionResult.ALLSAT region = stormpy.pars.ParameterRegion("0.4<=pL<=0.65,0.75<=pK<=0.95", parameters) - result = checker.check_region(region, stormpy.pars.RegionResult.UNKNOWN, True) + result = checker.check_region(region, stormpy.pars.RegionResultHypothesis.UNKNOWN, stormpy.pars.RegionResult.UNKNOWN, True) assert result == stormpy.pars.RegionResult.EXISTSBOTH region = stormpy.pars.ParameterRegion("0.1<=pL<=0.73,0.2<=pK<=0.715", parameters) result = checker.check_region(region)