Browse Source

Python iterators for models

Former-commit-id: 44ea006d62
tempestpy_adaptions
Mavo 9 years ago
committed by Matthias Volk
parent
commit
2da3e6eaad
  1. 1
      stormpy/lib/stormpy/storage/__init__.py
  2. 34
      stormpy/lib/stormpy/storage/action.py
  3. 34
      stormpy/lib/stormpy/storage/state.py
  4. 32
      stormpy/src/storage/matrix.cpp
  5. 64
      stormpy/tests/storage/test_model_iterators.py

1
stormpy/lib/stormpy/storage/__init__.py

@ -1,2 +1,3 @@
from . import storage
from .storage import *
from . import state,action

34
stormpy/lib/stormpy/storage/action.py

@ -0,0 +1,34 @@
class Action:
""" Represents an action in the model """
def __init__(self, row_group_start, row_group_end, row, model):
""" Initialize
:param row_group_start: Start index of the row group in the matrix
:param row_group_end: End index of the row group in the matrix
:param row: Index of the corresponding row in the matrix
:param model: Corresponding model
"""
self.row_group_start = row_group_start
self.row_group_end = row_group_end
self.row = row - 1
self.model = model
assert row >= -1 and row + row_group_start <= row_group_end
def __iter__(self):
return self
def __next__(self):
if self.row + self.row_group_start >= self.row_group_end - 1:
raise StopIteration
else:
self.row += 1
return self
def __str__(self):
return "{}".format(self.row)
def transitions(self):
""" Get transitions associated with the action
:return List of tranistions
"""
return self.model.transition_matrix().get_row(self.row + self.row_group_start)

34
stormpy/lib/stormpy/storage/state.py

@ -0,0 +1,34 @@
import stormpy.storage
class State:
""" Represents a state in the model """
def __init__(self, id, model):
""" Initialize
:param id: Id of the state
:param model: Corresponding model
"""
self.id = id - 1
self.model = model
def __iter__(self):
return self
def __next__(self):
if self.id >= self.model.nr_states() - 1:
raise StopIteration
else:
self.id += 1
return self
def __str__(self):
return "{}".format(self.id)
def actions(self):
""" Get actions associated with the state
:return List of actions
"""
row_group_indices = self.model.transition_matrix().row_group_indices()
start = row_group_indices[self.id]
end = row_group_indices[self.id+1]
return stormpy.action.Action(start, end, 0, self.model)

32
stormpy/src/storage/matrix.cpp

@ -22,9 +22,41 @@ void define_sparse_matrix(py::module& m) {
.def("__iter__", [](storm::storage::SparseMatrix<double> const& matrix) {
return py::make_iterator(matrix.begin(), matrix.end());
}, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
.def("__str__", [](storm::storage::SparseMatrix<double> const& matrix) {
std::stringstream stream;
stream << matrix;
return stream.str();
})
.def("nr_rows", &storm::storage::SparseMatrix<double>::getRowCount, "Number of rows")
.def("nr_columns", &storm::storage::SparseMatrix<double>::getColumnCount, "Number of columns")
.def("nr_entries", &storm::storage::SparseMatrix<double>::getEntryCount, "Number of non-zero entries")
.def("row_group_indices", &storm::storage::SparseMatrix<double>::getRowGroupIndices, "Number of non-zero entries")
.def("get_row", [](storm::storage::SparseMatrix<double>& matrix, entry_index row) {
return matrix.getRows(row, row+1);
}, py::keep_alive<0, 1>() /* keep_alive seems to avoid problem with wrong values */, "Get rows from start to end")
.def("get_rows", [](storm::storage::SparseMatrix<double>& matrix, entry_index start, entry_index end) {
return matrix.getRows(start, end);
}, "Get rows from start to end")
.def("print_row", [](storm::storage::SparseMatrix<double>& matrix, entry_index row) {
std::stringstream stream;
auto rows = matrix.getRows(row, row+1);
for (auto transition : rows) {
stream << transition << ", ";
}
return stream.str();
})
;
py::class_<storm::storage::SparseMatrix<double>::rows>(m, "SparseMatrixRows", "Set of rows in a sparse matrix")
.def("__iter__", [](storm::storage::SparseMatrix<double>::rows& rows) {
return py::make_iterator(rows.begin(), rows.end());
}, py::keep_alive<0, 1>())
.def("__str__", [](storm::storage::SparseMatrix<double>::rows& rows) {
std::stringstream stream;
for (auto transition : rows) {
stream << transition << ", ";
}
return stream.str();
})
;
}

64
stormpy/tests/storage/test_model_iterators.py

@ -0,0 +1,64 @@
import stormpy
class TestModelIterators:
def test_states_dtmc(self):
model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab")
s = stormpy.state.State(0, model)
i = 0
for state in stormpy.state.State(0, model):
assert state.id == i
i += 1
assert i == model.nr_states()
def test_states_mdp(self):
model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab")
s = stormpy.state.State(0, model)
i = 0
for state in stormpy.state.State(0, model):
assert state.id == i
i += 1
assert i == model.nr_states()
def test_actions_dtmc(self):
model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab")
s = stormpy.state.State(0, model)
for state in stormpy.state.State(0, model):
for action in state.actions():
assert action.row == 0
def test_actions_mdp(self):
model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab")
s = stormpy.state.State(0, model)
for state in stormpy.state.State(0, model):
for action in state.actions():
assert action.row == 0 or action.row == 1
def test_transitions_dtmc(self):
transitions_orig = [(0, 0, 0), (0, 1, 0.5), (0, 2, 0.5), (1, 1, 0), (1, 3, 0.5), (1, 4, 0.5),
(2, 2, 0), (2, 5, 0.5), (2, 6, 0.5), (3, 1, 0.5), (3, 3, 0), (3, 7, 0.5),
(4, 4, 0), (4, 8, 0.5), (4, 9, 0.5), (5, 5, 0), (5, 10, 0.5), (5, 11, 0.5),
(6, 2, 0.5), (6, 6, 0), (6, 12, 0.5), (7, 7, 1), (8, 8, 1),
(9, 9, 1), (10, 10, 1), (11, 11, 1), (12, 12, 1)
]
model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab")
s = stormpy.state.State(0, model)
i = 0
for state in stormpy.state.State(0, model):
for action in state.actions():
for transition in action.transitions():
transition_orig = transitions_orig[i]
i += 1
assert state.id == transition_orig[0]
assert transition.column() == transition_orig[1]
assert transition.val() == transition_orig[2]
def test_transitions_mdp(self):
model = stormpy.parse_explicit_model("../examples/mdp/two_dice/two_dice.tra", "../examples/mdp/two_dice/two_dice.lab")
s = stormpy.state.State(0, model)
for state in stormpy.state.State(0, model):
i = 0
for action in state.actions():
i += 1
for transition in action.transitions():
assert transition.val() == 0.5 or transition.val() == 1
assert i == 1 or i == 2
Loading…
Cancel
Save