diff --git a/lib/stormpy/storage/action.py b/lib/stormpy/storage/action.py index 420f5b3..ad6f68e 100644 --- a/lib/stormpy/storage/action.py +++ b/lib/stormpy/storage/action.py @@ -1,17 +1,17 @@ class Action: """ Represents an action in the model """ - def __init__(self, row_group_start, row_group_end, row, model): + def __init__(self, row_group_start, row_group_end, row, matrix): """ Initialize :param row_group_start: Start index of the row group in the matrix :param row_group_end: End index of the row group in the matrix :param row: Index of the corresponding row in the matrix - :param model: Corresponding model + :param matrix: Corresponding matrix """ self.row_group_start = row_group_start self.row_group_end = row_group_end self.row = row - 1 - self.model = model + self.matrix = matrix assert row >= -1 and row + row_group_start <= row_group_end def __iter__(self): @@ -33,4 +33,4 @@ class Action: """ row = self.row_group_start + self.row #return self.model.transition_matrix().get_row(self.row_group_start + self.row) - return self.model.transition_matrix.row_iter(row, row) + return self.matrix.row_iter(row, row) diff --git a/lib/stormpy/storage/state.py b/lib/stormpy/storage/state.py index 6a315e7..613316e 100644 --- a/lib/stormpy/storage/state.py +++ b/lib/stormpy/storage/state.py @@ -1,21 +1,21 @@ from . import action class State: - """ Represents a state in the model """ + """ Represents a state in the matrix """ - def __init__(self, id, model): + def __init__(self, id, matrix): """ Initialize :param id: Id of the state - :param model: Corresponding model + :param matrix: Corresponding matrix """ self.id = id - 1 - self.model = model + self.matrix = matrix def __iter__(self): return self def __next__(self): - if self.id >= self.model.nr_states - 1: + if self.id >= self.matrix.nr_row_groups - 1: raise StopIteration else: self.id += 1 @@ -28,7 +28,7 @@ class State: """ Get actions associated with the state :return List of actions """ - row_group_indices = self.model.transition_matrix._row_group_indices + row_group_indices = self.matrix._row_group_indices start = row_group_indices[self.id] end = row_group_indices[self.id+1] - return action.Action(start, end, 0, self.model) + return action.Action(start, end, 0, self.matrix) diff --git a/src/storage/matrix.cpp b/src/storage/matrix.cpp index 887a866..567cbdf 100644 --- a/src/storage/matrix.cpp +++ b/src/storage/matrix.cpp @@ -44,6 +44,7 @@ void define_sparse_matrix(py::module& m) { return stream.str(); }) .def_property_readonly("nr_rows", &storm::storage::SparseMatrix::getRowCount, "Number of rows") + .def_property_readonly("nr_row_groups", &storm::storage::SparseMatrix::getRowGroupCount, "Number of row groups") .def_property_readonly("nr_columns", &storm::storage::SparseMatrix::getColumnCount, "Number of columns") .def_property_readonly("nr_entries", &storm::storage::SparseMatrix::getEntryCount, "Number of non-zero entries") .def_property_readonly("_row_group_indices", &storm::storage::SparseMatrix::getRowGroupIndices, "Starting rows of row groups") @@ -77,6 +78,7 @@ void define_sparse_matrix(py::module& m) { return stream.str(); }) .def_property_readonly("nr_rows", &storm::storage::SparseMatrix::getRowCount, "Number of rows") + .def_property_readonly("nr_row_groups", &storm::storage::SparseMatrix::getRowGroupCount, "Number of row groups") .def_property_readonly("nr_columns", &storm::storage::SparseMatrix::getColumnCount, "Number of columns") .def_property_readonly("nr_entries", &storm::storage::SparseMatrix::getEntryCount, "Number of non-zero entries") .def_property_readonly("_row_group_indices", &storm::storage::SparseMatrix::getRowGroupIndices, "Starting rows of row groups") diff --git a/src/storage/model.cpp b/src/storage/model.cpp index 17c71ec..839ee4e 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -22,6 +22,11 @@ storm::storage::SparseMatrix& getTransitionMatrix(storm::models::spar return model.getTransitionMatrix(); } +template +storm::storage::SparseMatrix getBackwardTransitionMatrix(storm::models::sparse::Model& model) { + return model.getBackwardTransitions(); +} + std::set probabilityVariables(storm::models::sparse::Model const& model) { return storm::models::sparse::getProbabilityParameters(model); } @@ -72,6 +77,7 @@ void define_model(py::module& m) { .def("labels_state", &storm::models::sparse::Model::getLabelsOfState, py::arg("state"), "Get labels of state") .def_property_readonly("initial_states", &getInitialStates, "Initial states") .def_property_readonly("transition_matrix", &getTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix") + .def_property_readonly("backward_transition_matrix", &getBackwardTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Backward transition matrix") ; py::class_, std::shared_ptr>>(m, "SparseDtmc", "DTMC in sparse representation", model) ; @@ -89,6 +95,7 @@ void define_model(py::module& m) { .def("labels_state", &storm::models::sparse::Model::getLabelsOfState, py::arg("state"), "Get labels of state") .def_property_readonly("initial_states", &getInitialStates, "Initial states") .def_property_readonly("transition_matrix", &getTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix") + .def_property_readonly("backward_transition_matrix", &getBackwardTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Backward transition matrix") ; py::class_, std::shared_ptr>>(m, "SparseParametricDtmc", "pDTMC in sparse representation", modelRatFunc) ; diff --git a/tests/storage/test_matrix.py b/tests/storage/test_matrix.py index 4ca2faf..53fdc91 100644 --- a/tests/storage/test_matrix.py +++ b/tests/storage/test_matrix.py @@ -14,6 +14,16 @@ class TestMatrix: for e in matrix: assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6) + def test_backward_matrix(self): + model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) + matrix = model.backward_transition_matrix + assert type(matrix) is stormpy.storage.SparseMatrix + assert matrix.nr_rows == model.nr_states + assert matrix.nr_columns == model.nr_states + assert matrix.nr_entries == 20 #model.nr_transitions + for e in matrix: + assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6) + def test_change_sparse_matrix(self): model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) matrix = model.transition_matrix @@ -58,7 +68,7 @@ class TestMatrix: assert math.isclose(resValue, 0.06923076923076932) # Change probabilities again - for state in stormpy.state.State(0, model): + for state in stormpy.state.State(0, model.transition_matrix): for action in state.actions(): for transition in action.transitions(): if transition.value() == 0.3: diff --git a/tests/storage/test_model_iterators.py b/tests/storage/test_model_iterators.py index 8023b86..1f80e97 100644 --- a/tests/storage/test_model_iterators.py +++ b/tests/storage/test_model_iterators.py @@ -4,33 +4,33 @@ from helpers.helper import get_example_path class TestModelIterators: def test_states_dtmc(self): model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) - s = stormpy.state.State(0, model) + s = stormpy.state.State(0, model.transition_matrix) i = 0 - for state in stormpy.state.State(0, model): + for state in stormpy.state.State(0, model.transition_matrix): assert state.id == i i += 1 assert i == model.nr_states def test_states_mdp(self): model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) - s = stormpy.state.State(0, model) + s = stormpy.state.State(0, model.transition_matrix) i = 0 - for state in stormpy.state.State(0, model): + for state in stormpy.state.State(0, model.transition_matrix): assert state.id == i i += 1 assert i == model.nr_states def test_actions_dtmc(self): model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) - s = stormpy.state.State(0, model) - for state in stormpy.state.State(0, model): + s = stormpy.state.State(0, model.transition_matrix) + for state in stormpy.state.State(0, model.transition_matrix): for action in state.actions(): assert action.row == 0 def test_actions_mdp(self): model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) - s = stormpy.state.State(0, model) - for state in stormpy.state.State(0, model): + s = stormpy.state.State(0, model.transition_matrix) + for state in stormpy.state.State(0, model.transition_matrix): for action in state.actions(): assert action.row == 0 or action.row == 1 @@ -42,9 +42,9 @@ class TestModelIterators: (9, 9, 1), (10, 10, 1), (11, 11, 1), (12, 12, 1) ] model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) - s = stormpy.state.State(0, model) + s = stormpy.state.State(0, model.transition_matrix) i = 0 - for state in stormpy.state.State(0, model): + for state in stormpy.state.State(0, model.transition_matrix): for action in state.actions(): for transition in action.transitions(): transition_orig = transitions_orig[i] @@ -55,8 +55,8 @@ class TestModelIterators: def test_transitions_mdp(self): model = stormpy.parse_explicit_model(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab")) - s = stormpy.state.State(0, model) - for state in stormpy.state.State(0, model): + s = stormpy.state.State(0, model.transition_matrix) + for state in stormpy.state.State(0, model.transition_matrix): i = 0 for action in state.actions(): i += 1 @@ -73,7 +73,7 @@ class TestModelIterators: ] model = stormpy.parse_explicit_model(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) i = 0 - for state in stormpy.state.State(0, model): + for state in stormpy.state.State(0, model.transition_matrix): for transition in model.transition_matrix.row_iter(state.id, state.id): transition_orig = transitions_orig[i] i += 1