diff --git a/lib/stormpy/__init__.py b/lib/stormpy/__init__.py index 1cb559d..b3907dd 100644 --- a/lib/stormpy/__init__.py +++ b/lib/stormpy/__init__.py @@ -46,3 +46,25 @@ def model_checking(model, property): return core._parametric_model_checking(model, property.raw_formula) else: return core._model_checking(model, property.raw_formula) + + +def compute_prob01_states(model, phi_states, psi_states): + if model.model_type != ModelType.DTMC: + raise ValueError("Prob 01 is only defined for DTMCs -- model must be a DTMC") + + if model.supports_parameters: + return core._compute_prob01states_rationalfunc(model, phi_states, psi_states) + else: + return core._compute_prob01states_double(model, phi_states, psi_states) + +def compute_prob01min_states(model, phi_states, psi_states): + if model.supports_parameters: + return core._compute_prob01states_min_rationalfunc(model, phi_states, psi_states) + else: + return core._compute_prob01states_min_double(model, phi_states, psi_states) + +def compute_prob01max_states(model, phi_states, psi_states): + if model.supports_parameters: + return core._compute_prob01states_max_rationalfunc(model, phi_states, psi_states) + else: + return core._compute_prob01states_max_double(model, phi_states, psi_states) diff --git a/src/core/modelchecking.cpp b/src/core/modelchecking.cpp index aac2f5a..8e526fb 100644 --- a/src/core/modelchecking.cpp +++ b/src/core/modelchecking.cpp @@ -16,16 +16,34 @@ std::shared_ptr parametricModelChecking(std::shared_ptrconstraintsGraphPreserving = constraintCollector.getGraphPreservingConstraints(); return result; } + // Thin wrapper for computing prob01 states -std::pair computeProb01(storm::models::sparse::Dtmc model, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { +template +std::pair computeProb01(storm::models::sparse::Dtmc const& model, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { return storm::utility::graph::performProb01(model, phiStates, psiStates); } +template +std::pair computeProb01min(storm::models::sparse::Mdp const& model, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { + return storm::utility::graph::performProb01Min(model, phiStates, psiStates); +} + +template +std::pair computeProb01max(storm::models::sparse::Mdp const& model, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates) { + return storm::utility::graph::performProb01Max(model, phiStates, psiStates); +} + // Define python bindings void define_modelchecking(py::module& m) { // Model checking m.def("_model_checking", &modelChecking, "Perform model checking", py::arg("model"), py::arg("formula")); m.def("_parametric_model_checking", ¶metricModelChecking, "Perform parametric model checking", py::arg("model"), py::arg("formula")); - m.def("compute_prob01states", &computeProb01, "Compute prob-0-1 states", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + m.def("_compute_prob01states_double", &computeProb01, "Compute prob-0-1 states", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + m.def("_compute_prob01states_rationalfunc", &computeProb01, "Compute prob-0-1 states", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + m.def("_compute_prob01states_min_double", &computeProb01min, "Compute prob-0-1 states (min)", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + m.def("_compute_prob01states_max_double", &computeProb01max, "Compute prob-0-1 states (max)", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + m.def("_compute_prob01states_min_rationalfunc", &computeProb01min, "Compute prob-0-1 states (min)", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + m.def("_compute_prob01states_max_rationalfunc", &computeProb01max, "Compute prob-0-1 states (max)", py::arg("model"), py::arg("phi_states"), py::arg("psi_states")); + } diff --git a/tests/core/test_modelchecking.py b/tests/core/test_modelchecking.py index 814d985..7b515b6 100644 --- a/tests/core/test_modelchecking.py +++ b/tests/core/test_modelchecking.py @@ -61,6 +61,6 @@ class TestModelChecking: psiResult = stormpy.model_checking(model, formulaPsi) psiStates = psiResult.get_truth_values() assert psiStates.number_of_set_bits() == 1 - (prob0, prob1) = stormpy.compute_prob01states(model, phiStates, psiStates) + (prob0, prob1) = stormpy.compute_prob01_states(model, phiStates, psiStates) assert prob0.number_of_set_bits() == 9 assert prob1.number_of_set_bits() == 1