From 2da3e6eaad3a5421d558c767c7adea6ec7dd807b Mon Sep 17 00:00:00 2001
From: Mavo <matthias.volk@rwth-aachen.de>
Date: Fri, 27 May 2016 10:59:43 +0200
Subject: [PATCH] Python iterators for models

Former-commit-id: 44ea006d6216edc30ccf8fb387eabf8a7b1e8209
---
 stormpy/lib/stormpy/storage/__init__.py       |  1 +
 stormpy/lib/stormpy/storage/action.py         | 34 ++++++++++
 stormpy/lib/stormpy/storage/state.py          | 34 ++++++++++
 stormpy/src/storage/matrix.cpp                | 32 ++++++++++
 stormpy/tests/storage/test_model_iterators.py | 64 +++++++++++++++++++
 5 files changed, 165 insertions(+)
 create mode 100644 stormpy/lib/stormpy/storage/action.py
 create mode 100644 stormpy/lib/stormpy/storage/state.py
 create mode 100644 stormpy/tests/storage/test_model_iterators.py

diff --git a/stormpy/lib/stormpy/storage/__init__.py b/stormpy/lib/stormpy/storage/__init__.py
index 48053d030..c68800f3a 100644
--- a/stormpy/lib/stormpy/storage/__init__.py
+++ b/stormpy/lib/stormpy/storage/__init__.py
@@ -1,2 +1,3 @@
 from . import storage
 from .storage import *
+from . import state,action
diff --git a/stormpy/lib/stormpy/storage/action.py b/stormpy/lib/stormpy/storage/action.py
new file mode 100644
index 000000000..86bf2cd02
--- /dev/null
+++ b/stormpy/lib/stormpy/storage/action.py
@@ -0,0 +1,34 @@
+class Action:
+    """ Represents an action in the model """
+
+    def __init__(self, row_group_start, row_group_end, row, model):
+        """ Initialize
+        :param row_group_start: Start index of the row group in the matrix
+        :param row_group_end: End index of the row group in the matrix
+        :param row: Index of the corresponding row in the matrix
+        :param model: Corresponding model
+        """
+        self.row_group_start = row_group_start
+        self.row_group_end = row_group_end
+        self.row = row - 1
+        self.model = model
+        assert row >= -1 and row + row_group_start <= row_group_end
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.row + self.row_group_start >= self.row_group_end - 1:
+            raise StopIteration
+        else:
+            self.row += 1
+            return self
+
+    def __str__(self):
+        return "{}".format(self.row)
+
+    def transitions(self):
+        """ Get transitions associated with the action
+        :return List of tranistions
+        """
+        return self.model.transition_matrix().get_row(self.row + self.row_group_start)
diff --git a/stormpy/lib/stormpy/storage/state.py b/stormpy/lib/stormpy/storage/state.py
new file mode 100644
index 000000000..1faa5fc88
--- /dev/null
+++ b/stormpy/lib/stormpy/storage/state.py
@@ -0,0 +1,34 @@
+import stormpy.storage
+
+class State:
+    """ Represents a state in the model """
+
+    def __init__(self, id, model):
+        """ Initialize
+        :param id: Id of the state
+        :param model: Corresponding model
+        """
+        self.id = id - 1
+        self.model = model
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.id >= self.model.nr_states() - 1:
+            raise StopIteration
+        else:
+            self.id += 1
+            return self
+
+    def __str__(self):
+        return "{}".format(self.id)
+
+    def actions(self):
+        """ Get actions associated with the state
+        :return List of actions
+        """
+        row_group_indices = self.model.transition_matrix().row_group_indices()
+        start = row_group_indices[self.id]
+        end = row_group_indices[self.id+1]
+        return stormpy.action.Action(start, end, 0, self.model)
diff --git a/stormpy/src/storage/matrix.cpp b/stormpy/src/storage/matrix.cpp
index 6db2de7c3..860255a81 100644
--- a/stormpy/src/storage/matrix.cpp
+++ b/stormpy/src/storage/matrix.cpp
@@ -22,9 +22,41 @@ void define_sparse_matrix(py::module& m) {
         .def("__iter__", [](storm::storage::SparseMatrix<double> const& matrix) {
                 return py::make_iterator(matrix.begin(), matrix.end());
             }, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
+        .def("__str__", [](storm::storage::SparseMatrix<double> const& matrix) {
+                std::stringstream stream;
+                stream << matrix;
+                return stream.str();
+            })
         .def("nr_rows", &storm::storage::SparseMatrix<double>::getRowCount, "Number of rows")
         .def("nr_columns", &storm::storage::SparseMatrix<double>::getColumnCount, "Number of columns")
         .def("nr_entries", &storm::storage::SparseMatrix<double>::getEntryCount, "Number of non-zero entries")
+        .def("row_group_indices", &storm::storage::SparseMatrix<double>::getRowGroupIndices, "Number of non-zero entries")
+        .def("get_row", [](storm::storage::SparseMatrix<double>& matrix, entry_index row) {
+                return matrix.getRows(row, row+1);
+            }, py::keep_alive<0, 1>() /* keep_alive seems to avoid problem with wrong values */, "Get rows from start to end")
+        .def("get_rows", [](storm::storage::SparseMatrix<double>& matrix, entry_index start, entry_index end) {
+                return matrix.getRows(start, end);
+            }, "Get rows from start to end")
+        .def("print_row", [](storm::storage::SparseMatrix<double>& matrix, entry_index row) {
+                std::stringstream stream;
+                auto rows = matrix.getRows(row, row+1);
+                for (auto transition : rows) {
+                    stream << transition << ", ";
+                }
+                return stream.str();
+            })
     ;
 
+    py::class_<storm::storage::SparseMatrix<double>::rows>(m, "SparseMatrixRows", "Set of rows in a sparse matrix")
+        .def("__iter__", [](storm::storage::SparseMatrix<double>::rows& rows) {
+                return py::make_iterator(rows.begin(), rows.end());
+            }, py::keep_alive<0, 1>())
+        .def("__str__", [](storm::storage::SparseMatrix<double>::rows& rows) {
+                std::stringstream stream;
+                for (auto transition : rows) {
+                    stream << transition << ", ";
+                }
+                return stream.str();
+            })
+    ;
 }
diff --git a/stormpy/tests/storage/test_model_iterators.py b/stormpy/tests/storage/test_model_iterators.py
new file mode 100644
index 000000000..9b090b9bb
--- /dev/null
+++ b/stormpy/tests/storage/test_model_iterators.py
@@ -0,0 +1,64 @@
+import stormpy
+
+class TestModelIterators:
+    def test_states_dtmc(self):
+        model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab")
+        s = stormpy.state.State(0, model)
+        i = 0
+        for state in stormpy.state.State(0, model):
+            assert state.id == i
+            i += 1
+        assert i == model.nr_states()
+    
+    def test_states_mdp(self):
+        model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab")
+        s = stormpy.state.State(0, model)
+        i = 0
+        for state in stormpy.state.State(0, model):
+            assert state.id == i
+            i += 1
+        assert i == model.nr_states()
+    
+    def test_actions_dtmc(self):
+        model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab")
+        s = stormpy.state.State(0, model)
+        for state in stormpy.state.State(0, model):
+            for action in state.actions():
+                assert action.row == 0
+    
+    def test_actions_mdp(self):
+        model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab")
+        s = stormpy.state.State(0, model)
+        for state in stormpy.state.State(0, model):
+            for action in state.actions():
+                assert action.row == 0 or action.row == 1
+    
+    def test_transitions_dtmc(self):
+        transitions_orig = [(0, 0, 0), (0, 1, 0.5), (0, 2, 0.5), (1, 1, 0), (1, 3, 0.5), (1, 4, 0.5), 
+                (2, 2, 0), (2, 5, 0.5), (2, 6, 0.5), (3, 1, 0.5), (3, 3, 0), (3, 7, 0.5),
+                (4, 4, 0), (4, 8, 0.5), (4, 9, 0.5), (5, 5, 0), (5, 10, 0.5), (5, 11, 0.5),
+                (6, 2, 0.5), (6, 6, 0), (6, 12, 0.5), (7, 7, 1), (8, 8, 1),
+                (9, 9, 1), (10, 10, 1), (11, 11, 1), (12, 12, 1)
+            ]
+        model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab")
+        s = stormpy.state.State(0, model)
+        i = 0
+        for state in stormpy.state.State(0, model):
+            for action in state.actions():
+                for transition in action.transitions():
+                    transition_orig = transitions_orig[i]
+                    i += 1
+                    assert state.id == transition_orig[0]
+                    assert transition.column() == transition_orig[1]
+                    assert transition.val() == transition_orig[2]
+    
+    def test_transitions_mdp(self):
+        model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab")
+        s = stormpy.state.State(0, model)
+        for state in stormpy.state.State(0, model):
+            i = 0
+            for action in state.actions():
+                i += 1
+                for transition in action.transitions():
+                    assert transition.val() == 0.5 or transition.val() == 1
+            assert i == 1 or i == 2