From 26d2b90f442a3687c9bc3e8dd6b153044f36a9e0 Mon Sep 17 00:00:00 2001 From: Matthias Volk Date: Tue, 7 Mar 2017 19:09:01 +0100 Subject: [PATCH] Bindings for state labeling --- src/mod_storage.cpp | 2 ++ src/storage/bitvector.cpp | 7 ++++- src/storage/labeling.cpp | 29 ++++++++++++++++++++ src/storage/labeling.h | 8 ++++++ src/storage/model.cpp | 13 ++++----- tests/storage/test_labeling.py | 49 ++++++++++++++++++++++++++++++++++ tests/storage/test_model.py | 26 ------------------ 7 files changed, 101 insertions(+), 33 deletions(-) create mode 100644 src/storage/labeling.cpp create mode 100644 src/storage/labeling.h create mode 100644 tests/storage/test_labeling.py diff --git a/src/mod_storage.cpp b/src/mod_storage.cpp index 003646b..658043c 100644 --- a/src/mod_storage.cpp +++ b/src/mod_storage.cpp @@ -3,6 +3,7 @@ #include "storage/bitvector.h" #include "storage/model.h" #include "storage/matrix.h" +#include "storage/labeling.h" PYBIND11_PLUGIN(storage) { py::module m("storage"); @@ -16,5 +17,6 @@ PYBIND11_PLUGIN(storage) { define_model(m); define_sparse_matrix(m); define_model_instantiator(m); + define_labeling(m); return m.ptr(); } diff --git a/src/storage/bitvector.cpp b/src/storage/bitvector.cpp index dc24960..c40e211 100644 --- a/src/storage/bitvector.cpp +++ b/src/storage/bitvector.cpp @@ -16,7 +16,12 @@ void define_bitvector(py::module& m) { .def("size", &BitVector::size) .def("number_of_set_bits", &BitVector::getNumberOfSetBits) //.def("get", &BitVector::get, "index"_a) // no idea why this does not work - .def("get", [](BitVector const& b, uint_fast64_t i) { return b.get(i); }, "index"_a) + .def("get", [](BitVector const& b, uint_fast64_t i) { + return b.get(i); + }, "index"_a) + .def("set", [](BitVector& b, uint_fast64_t i, bool v) { + b.set(i, v); + }, py::arg("index"), py::arg("value") = true, "Set") .def(py::self == py::self) .def(py::self != py::self) diff --git a/src/storage/labeling.cpp b/src/storage/labeling.cpp new file mode 100644 index 0000000..f0217e6 --- /dev/null +++ b/src/storage/labeling.cpp @@ -0,0 +1,29 @@ +#include "labeling.h" + +#include "storm/models/sparse/StateLabeling.h" + +// Define python bindings +void define_labeling(py::module& m) { + + // StateLabeling + py::class_>(m, "StateLabeling", "Labeling for states") + .def("add_label", [](storm::models::sparse::StateLabeling& labeling, std::string label) { + labeling.addLabel(label); + }, py::arg("label"), "Add label") + .def("get_labels", &storm::models::sparse::StateLabeling::getLabels, "Get all labels") + .def("get_labels_of_state", &storm::models::sparse::StateLabeling::getLabelsOfState, "Get labels of given state", py::arg("state")) + .def("contains_label", &storm::models::sparse::StateLabeling::containsLabel, "Check if the given label is contained in the labeling", py::arg("label")) + .def("add_label_to_state", &storm::models::sparse::StateLabeling::addLabelToState, "Add label to state", py::arg("label"), py::arg("state")) + .def("has_state_label", &storm::models::sparse::StateLabeling::getStateHasLabel, "Check if the given state has the given label", py::arg("label"), py::arg("state")) + .def("get_states", &storm::models::sparse::StateLabeling::getStates, "Get all states which have the given label", py::arg("label")) + .def("set_states", [](storm::models::sparse::StateLabeling& labeling, std::string const& label, storm::storage::BitVector const& states) { + labeling.setStates(label, states); + }, "Set all states which have the given label", py::arg("label"), py::arg("states")) + .def("__str__", [](storm::models::sparse::StateLabeling const& labeling) { + std::ostringstream oss; + oss << labeling; + return oss.str(); + }) + ; + +} diff --git a/src/storage/labeling.h b/src/storage/labeling.h new file mode 100644 index 0000000..1a4e420 --- /dev/null +++ b/src/storage/labeling.h @@ -0,0 +1,8 @@ +#ifndef PYTHON_STORAGE_LABELING_H_ +#define PYTHON_STORAGE_LABELING_H_ + +#include "common.h" + +void define_labeling(py::module& m); + +#endif /* PYTHON_STORAGE_LABELING_H_ */ diff --git a/src/storage/model.cpp b/src/storage/model.cpp index f3bfd31..17c71ec 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -30,6 +30,11 @@ std::set rewardVariables(storm::models::sparse: return storm::models::sparse::getRewardParameters(model); } +template +storm::models::sparse::StateLabeling& getLabeling(storm::models::sparse::Model& model) { + return model.getStateLabeling(); +} + // Define python bindings void define_model(py::module& m) { @@ -63,9 +68,7 @@ void define_model(py::module& m) { //storm::models::sparse::Model > py::class_, std::shared_ptr>> model(m, "_SparseModel", "A probabilistic model where transitions are represented by doubles and saved in a sparse matrix", modelBase); - model.def_property_readonly("labels", [](storm::models::sparse::Model& model) { - return model.getStateLabeling().getLabels(); - }, "Labels") + model.def_property_readonly("labeling", &getLabeling, "Labels") .def("labels_state", &storm::models::sparse::Model::getLabelsOfState, py::arg("state"), "Get labels of state") .def_property_readonly("initial_states", &getInitialStates, "Initial states") .def_property_readonly("transition_matrix", &getTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix") @@ -82,9 +85,7 @@ void define_model(py::module& m) { py::class_, std::shared_ptr>> modelRatFunc(m, "_SparseParametricModel", "A probabilistic model where transitions are represented by rational functions and saved in a sparse matrix", modelBase); modelRatFunc.def("collect_probability_parameters", &probabilityVariables, "Collect parameters") .def("collect_reward_parameters", &rewardVariables, "Collect reward parameters") - .def_property_readonly("labels", [](storm::models::sparse::Model& model) { - return model.getStateLabeling().getLabels(); - }, "Labels") + .def_property_readonly("labeling", &getLabeling, "Labels") .def("labels_state", &storm::models::sparse::Model::getLabelsOfState, py::arg("state"), "Get labels of state") .def_property_readonly("initial_states", &getInitialStates, "Initial states") .def_property_readonly("transition_matrix", &getTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix") diff --git a/tests/storage/test_labeling.py b/tests/storage/test_labeling.py new file mode 100644 index 0000000..190a53e --- /dev/null +++ b/tests/storage/test_labeling.py @@ -0,0 +1,49 @@ +import stormpy +import stormpy.logic +from helpers.helper import get_example_path + +class TestLabeling: + def test_set_labeling(self): + program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) + model = stormpy.build_model(program) + labeling = model.labeling + assert "tmp" not in labeling.get_labels() + assert not labeling.contains_label("tmp") + labeling.add_label("tmp") + assert labeling.contains_label("tmp") + labeling.add_label_to_state("tmp", 0) + assert labeling.has_state_label("tmp", 0) + states = labeling.get_states("tmp") + assert states.get(0) + states.set(3) + labeling.set_states("tmp", states) + assert labeling.has_state_label("tmp", 3) + + def test_label(self): + program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) + formulas = stormpy.parse_properties_for_prism_program("P=? [ F \"one\" ]", program) + model = stormpy.build_model(program, formulas) + labeling = model.labeling + labels = labeling.get_labels() + assert len(labels) == 3 + assert "init" in labels + assert "one" in labels + assert "init" in model.labels_state(0) + assert "init" in labeling.get_labels_of_state(0) + assert "one" in model.labels_state(7) + assert "one" in labeling.get_labels_of_state(7) + + def test_label_parametric(self): + program = stormpy.parse_prism_program(get_example_path("pdtmc", "brp16_2.pm")) + formulas = stormpy.parse_properties_for_prism_program("P=? [ F s=5 ]", program) + model = stormpy.build_parametric_model(program, formulas) + labels = model.labeling.get_labels() + assert len(labels) == 3 + assert "init" in labels + assert "(s = 5)" in labels + assert "init" in model.labels_state(0) + assert "(s = 5)" in model.labels_state(28) + assert "(s = 5)" in model.labels_state(611) + initial_states = model.initial_states + assert len(initial_states) == 1 + assert 0 in initial_states diff --git a/tests/storage/test_model.py b/tests/storage/test_model.py index 8ca25f4..df17d89 100644 --- a/tests/storage/test_model.py +++ b/tests/storage/test_model.py @@ -70,17 +70,6 @@ class TestModel: assert not model.has_parameters assert type(model) is stormpy.SparseParametricDtmc - def test_label(self): - program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) - formulas = stormpy.parse_properties_for_prism_program("P=? [ F \"one\" ]", program) - model = stormpy.build_model(program, formulas) - labels = model.labels - assert len(labels) == 3 - assert "init" in labels - assert "one" in labels - assert "init" in model.labels_state(0) - assert "one" in model.labels_state(7) - def test_initial_states(self): program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) formulas = stormpy.parse_properties_for_prism_program("P=? [ F \"one\" ]", program) @@ -88,18 +77,3 @@ class TestModel: initial_states = model.initial_states assert len(initial_states) == 1 assert 0 in initial_states - - def test_label_parametric(self): - program = stormpy.parse_prism_program(get_example_path("pdtmc", "brp16_2.pm")) - formulas = stormpy.parse_properties_for_prism_program("P=? [ F s=5 ]", program) - model = stormpy.build_parametric_model(program, formulas) - labels = model.labels - assert len(labels) == 3 - assert "init" in labels - assert "(s = 5)" in labels - assert "init" in model.labels_state(0) - assert "(s = 5)" in model.labels_state(28) - assert "(s = 5)" in model.labels_state(611) - initial_states = model.initial_states - assert len(initial_states) == 1 - assert 0 in initial_states