From 9618b5ca31a7f91c50aaa9a3caebe05f437c27ce Mon Sep 17 00:00:00 2001 From: Matthias Volk Date: Thu, 16 Mar 2017 14:16:55 +0100 Subject: [PATCH] Length for states and actions --- src/storage/matrix.cpp | 2 ++ src/storage/state.cpp | 4 ++++ src/storage/state.h | 26 +++++++++++++++++--------- tests/storage/test_state.py | 5 +++++ 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/storage/matrix.cpp b/src/storage/matrix.cpp index 5073c42..b5a8554 100644 --- a/src/storage/matrix.cpp +++ b/src/storage/matrix.cpp @@ -107,6 +107,7 @@ void define_sparse_matrix(py::module& m) { .def("__iter__", [](storm::storage::SparseMatrix::rows& rows) { return py::make_iterator(rows.begin(), rows.end()); }, py::keep_alive<0, 1>()) + .def("__len__", &storm::storage::SparseMatrix::rows::getNumberOfEntries) .def("__str__", [](storm::storage::SparseMatrix::const_rows& rows) { std::stringstream stream; for (auto transition : rows) { @@ -120,6 +121,7 @@ void define_sparse_matrix(py::module& m) { .def("__iter__", [](storm::storage::SparseMatrix::rows& rows) { return py::make_iterator(rows.begin(), rows.end()); }, py::keep_alive<0, 1>()) + .def("__len__", &storm::storage::SparseMatrix::rows::getNumberOfEntries) .def("__str__", [](storm::storage::SparseMatrix::const_rows& rows) { std::stringstream stream; for (auto transition : rows) { diff --git a/src/storage/state.cpp b/src/storage/state.cpp index d3724c1..fadd676 100644 --- a/src/storage/state.cpp +++ b/src/storage/state.cpp @@ -5,9 +5,11 @@ void define_state(py::module& m) { // SparseModelStates py::class_>(m, "SparseModelStates", "States in sparse model") .def("__getitem__", &SparseModelStates::getState) + .def("__len__", &SparseModelStates::getSize) ; py::class_>(m, "SparseParametricModelStates", "States in sparse parametric model") .def("__getitem__", &SparseModelStates::getState) + .def("__len__", &SparseModelStates::getSize) ; // SparseModelState py::class_>(m, "SparseModelState", "State in sparse model") @@ -26,9 +28,11 @@ void define_state(py::module& m) { // SparseModelActions py::class_>(m, "SparseModelActions", "Actions for state in sparse model") .def("__getitem__", &SparseModelActions::getAction) + .def("__len__", &SparseModelActions::getSize) ; py::class_>(m, "SparseParametricModelActions", "Actions for state in sparse parametric model") .def("__getitem__", &SparseModelActions::getAction) + .def("__len__", &SparseModelActions::getSize) ; // SparseModelAction py::class_>(m, "SparseModelAction", "Action for state in sparse model") diff --git a/src/storage/state.h b/src/storage/state.h index a2db40e..0e91a11 100644 --- a/src/storage/state.h +++ b/src/storage/state.h @@ -20,20 +20,20 @@ class SparseModelState { } s_index getIndex() const { - return this->stateIndex; + return stateIndex; } - std::set getLabels() { + std::set getLabels() const { return this->model.getStateLabeling().getLabelsOfState(this->stateIndex); } - SparseModelActions getActions() { + SparseModelActions getActions() const { return SparseModelActions(this->model, stateIndex); } - std::string toString() { + std::string toString() const { std::stringstream stream; - stream << this->getIndex(); + stream << stateIndex; return stream.str(); } @@ -51,7 +51,11 @@ class SparseModelStates { length = model.getNumberOfStates(); } - SparseModelState getState(s_index index) { + s_index getSize() const { + return length; + } + + SparseModelState getState(s_index index) const { if (index < length) { return SparseModelState(model, index); } else { @@ -81,9 +85,9 @@ class SparseModelAction { return model.getTransitionMatrix().getRow(stateIndex, actionIndex); } - std::string toString() { + std::string toString() const { std::stringstream stream; - stream << this->getIndex(); + stream << actionIndex; return stream.str(); } @@ -103,7 +107,11 @@ class SparseModelActions { length = model.getTransitionMatrix().getRowGroupSize(stateIndex); } - SparseModelAction getAction(size_t index) { + s_index getSize() const { + return length; + } + + SparseModelAction getAction(size_t index) const { if (index < length) { return SparseModelAction(model, stateIndex, index); } else { diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index add9dcd..a46c620 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -6,6 +6,7 @@ class TestState: model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) i = 0 states = model.states + assert len(states) == 13 for state in states: assert state.id == i i += 1 @@ -18,6 +19,7 @@ class TestState: def test_states_mdp(self): model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) i = 0 + assert len(model.states) == 169 for state in model.states: assert state.id == i i += 1 @@ -41,12 +43,14 @@ class TestState: def test_actions_dtmc(self): model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) for state in model.states: + assert len(state.actions) == 1 for action in state.actions: assert action.id == 0 def test_actions_mdp(self): model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) for state in model.states: + assert len(state.actions) == 1 or len(state.actions) == 2 for action in state.actions: assert action.id == 0 or action.id == 1 @@ -61,6 +65,7 @@ class TestState: i = 0 for state in model.states: for action in state.actions: + assert (state.id < 7 and len(action.transitions) == 3) or (state.id >= 7 and len(action.transitions) == 1) for transition in action.transitions: transition_orig = transitions_orig[i] i += 1