diff --git a/src/core/modelchecking.cpp b/src/core/modelchecking.cpp index dd00a28..84488c2 100644 --- a/src/core/modelchecking.cpp +++ b/src/core/modelchecking.cpp @@ -56,6 +56,10 @@ std::shared_ptr parametricModelChecking(std::shared_ptrconstraintsGraphPreserving = constraintCollector.getGraphPreservingConstraints(); return result; } +// Thin wrapper for computing prob01 states +std::pair computeProb01(storm::models::sparse::Dtmc model, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { + return storm::utility::graph::performProb01(model, phiStates, psiStates); +} // Define python bindings void define_modelchecking(py::module& m) { @@ -64,6 +68,7 @@ void define_modelchecking(py::module& m) { m.def("_model_checking", &modelChecking, "Perform model checking", py::arg("model"), py::arg("formula")); m.def("model_checking_all", &modelCheckingAll, "Perform model checking for all states", py::arg("model"), py::arg("formula")); m.def("_parametric_model_checking", ¶metricModelChecking, "Perform parametric model checking", py::arg("model"), py::arg("formula")); + m.def("compute_prob01states", &computeProb01, "Compute prob-0-1 states", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); // PmcResult py::class_>(m, "PmcResult", "Holds the results after parametric model checking") diff --git a/tests/core/test_modelchecking.py b/tests/core/test_modelchecking.py index e6ba3c8..d6ec19d 100644 --- a/tests/core/test_modelchecking.py +++ b/tests/core/test_modelchecking.py @@ -45,3 +45,13 @@ class TestModelChecking: constraints_graph_preserving = result.constraints_graph_preserving for constraint in constraints_graph_preserving: assert constraint.rel() == pycarl.formula.Relation.GREATER + + def test_model_checking_prob01(self): + program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) + formulas = stormpy.parse_formulas_for_prism_program("P=? [ F \"one\" ]", program) + model = stormpy.build_model(program, formulas[0]) + phiStates = stormpy.BitVector(model.nr_states, True) + psiStates = stormpy.BitVector(model.nr_states, [model.nr_states-1]) + (prob0, prob1) = stormpy.compute_prob01states(model, phiStates, psiStates) + assert prob0.number_of_set_bits() == 9 + assert prob1.number_of_set_bits() == 1