From 2da3e6eaad3a5421d558c767c7adea6ec7dd807b Mon Sep 17 00:00:00 2001 From: Mavo Date: Fri, 27 May 2016 10:59:43 +0200 Subject: [PATCH] Python iterators for models Former-commit-id: 44ea006d6216edc30ccf8fb387eabf8a7b1e8209 --- stormpy/lib/stormpy/storage/__init__.py | 1 + stormpy/lib/stormpy/storage/action.py | 34 ++++++++++ stormpy/lib/stormpy/storage/state.py | 34 ++++++++++ stormpy/src/storage/matrix.cpp | 32 ++++++++++ stormpy/tests/storage/test_model_iterators.py | 64 +++++++++++++++++++ 5 files changed, 165 insertions(+) create mode 100644 stormpy/lib/stormpy/storage/action.py create mode 100644 stormpy/lib/stormpy/storage/state.py create mode 100644 stormpy/tests/storage/test_model_iterators.py diff --git a/stormpy/lib/stormpy/storage/__init__.py b/stormpy/lib/stormpy/storage/__init__.py index 48053d030..c68800f3a 100644 --- a/stormpy/lib/stormpy/storage/__init__.py +++ b/stormpy/lib/stormpy/storage/__init__.py @@ -1,2 +1,3 @@ from . import storage from .storage import * +from . import state,action diff --git a/stormpy/lib/stormpy/storage/action.py b/stormpy/lib/stormpy/storage/action.py new file mode 100644 index 000000000..86bf2cd02 --- /dev/null +++ b/stormpy/lib/stormpy/storage/action.py @@ -0,0 +1,34 @@ +class Action: + """ Represents an action in the model """ + + def __init__(self, row_group_start, row_group_end, row, model): + """ Initialize + :param row_group_start: Start index of the row group in the matrix + :param row_group_end: End index of the row group in the matrix + :param row: Index of the corresponding row in the matrix + :param model: Corresponding model + """ + self.row_group_start = row_group_start + self.row_group_end = row_group_end + self.row = row - 1 + self.model = model + assert row >= -1 and row + row_group_start <= row_group_end + + def __iter__(self): + return self + + def __next__(self): + if self.row + self.row_group_start >= self.row_group_end - 1: + raise StopIteration + else: + self.row += 1 + return self + + def __str__(self): + return "{}".format(self.row) + + def transitions(self): + """ Get transitions associated with the action + :return List of tranistions + """ + return self.model.transition_matrix().get_row(self.row + self.row_group_start) diff --git a/stormpy/lib/stormpy/storage/state.py b/stormpy/lib/stormpy/storage/state.py new file mode 100644 index 000000000..1faa5fc88 --- /dev/null +++ b/stormpy/lib/stormpy/storage/state.py @@ -0,0 +1,34 @@ +import stormpy.storage + +class State: + """ Represents a state in the model """ + + def __init__(self, id, model): + """ Initialize + :param id: Id of the state + :param model: Corresponding model + """ + self.id = id - 1 + self.model = model + + def __iter__(self): + return self + + def __next__(self): + if self.id >= self.model.nr_states() - 1: + raise StopIteration + else: + self.id += 1 + return self + + def __str__(self): + return "{}".format(self.id) + + def actions(self): + """ Get actions associated with the state + :return List of actions + """ + row_group_indices = self.model.transition_matrix().row_group_indices() + start = row_group_indices[self.id] + end = row_group_indices[self.id+1] + return stormpy.action.Action(start, end, 0, self.model) diff --git a/stormpy/src/storage/matrix.cpp b/stormpy/src/storage/matrix.cpp index 6db2de7c3..860255a81 100644 --- a/stormpy/src/storage/matrix.cpp +++ b/stormpy/src/storage/matrix.cpp @@ -22,9 +22,41 @@ void define_sparse_matrix(py::module& m) { .def("__iter__", [](storm::storage::SparseMatrix const& matrix) { return py::make_iterator(matrix.begin(), matrix.end()); }, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) + .def("__str__", [](storm::storage::SparseMatrix const& matrix) { + std::stringstream stream; + stream << matrix; + return stream.str(); + }) .def("nr_rows", &storm::storage::SparseMatrix::getRowCount, "Number of rows") .def("nr_columns", &storm::storage::SparseMatrix::getColumnCount, "Number of columns") .def("nr_entries", &storm::storage::SparseMatrix::getEntryCount, "Number of non-zero entries") + .def("row_group_indices", &storm::storage::SparseMatrix::getRowGroupIndices, "Number of non-zero entries") + .def("get_row", [](storm::storage::SparseMatrix& matrix, entry_index row) { + return matrix.getRows(row, row+1); + }, py::keep_alive<0, 1>() /* keep_alive seems to avoid problem with wrong values */, "Get rows from start to end") + .def("get_rows", [](storm::storage::SparseMatrix& matrix, entry_index start, entry_index end) { + return matrix.getRows(start, end); + }, "Get rows from start to end") + .def("print_row", [](storm::storage::SparseMatrix& matrix, entry_index row) { + std::stringstream stream; + auto rows = matrix.getRows(row, row+1); + for (auto transition : rows) { + stream << transition << ", "; + } + return stream.str(); + }) ; + py::class_::rows>(m, "SparseMatrixRows", "Set of rows in a sparse matrix") + .def("__iter__", [](storm::storage::SparseMatrix::rows& rows) { + return py::make_iterator(rows.begin(), rows.end()); + }, py::keep_alive<0, 1>()) + .def("__str__", [](storm::storage::SparseMatrix::rows& rows) { + std::stringstream stream; + for (auto transition : rows) { + stream << transition << ", "; + } + return stream.str(); + }) + ; } diff --git a/stormpy/tests/storage/test_model_iterators.py b/stormpy/tests/storage/test_model_iterators.py new file mode 100644 index 000000000..9b090b9bb --- /dev/null +++ b/stormpy/tests/storage/test_model_iterators.py @@ -0,0 +1,64 @@ +import stormpy + +class TestModelIterators: + def test_states_dtmc(self): + model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") + s = stormpy.state.State(0, model) + i = 0 + for state in stormpy.state.State(0, model): + assert state.id == i + i += 1 + assert i == model.nr_states() + + def test_states_mdp(self): + model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab") + s = stormpy.state.State(0, model) + i = 0 + for state in stormpy.state.State(0, model): + assert state.id == i + i += 1 + assert i == model.nr_states() + + def test_actions_dtmc(self): + model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") + s = stormpy.state.State(0, model) + for state in stormpy.state.State(0, model): + for action in state.actions(): + assert action.row == 0 + + def test_actions_mdp(self): + model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab") + s = stormpy.state.State(0, model) + for state in stormpy.state.State(0, model): + for action in state.actions(): + assert action.row == 0 or action.row == 1 + + def test_transitions_dtmc(self): + transitions_orig = [(0, 0, 0), (0, 1, 0.5), (0, 2, 0.5), (1, 1, 0), (1, 3, 0.5), (1, 4, 0.5), + (2, 2, 0), (2, 5, 0.5), (2, 6, 0.5), (3, 1, 0.5), (3, 3, 0), (3, 7, 0.5), + (4, 4, 0), (4, 8, 0.5), (4, 9, 0.5), (5, 5, 0), (5, 10, 0.5), (5, 11, 0.5), + (6, 2, 0.5), (6, 6, 0), (6, 12, 0.5), (7, 7, 1), (8, 8, 1), + (9, 9, 1), (10, 10, 1), (11, 11, 1), (12, 12, 1) + ] + model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") + s = stormpy.state.State(0, model) + i = 0 + for state in stormpy.state.State(0, model): + for action in state.actions(): + for transition in action.transitions(): + transition_orig = transitions_orig[i] + i += 1 + assert state.id == transition_orig[0] + assert transition.column() == transition_orig[1] + assert transition.val() == transition_orig[2] + + def test_transitions_mdp(self): + model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab") + s = stormpy.state.State(0, model) + for state in stormpy.state.State(0, model): + i = 0 + for action in state.actions(): + i += 1 + for transition in action.transitions(): + assert transition.val() == 0.5 or transition.val() == 1 + assert i == 1 or i == 2