diff --git a/stormpy/src/storage/matrix.cpp b/stormpy/src/storage/matrix.cpp index 860255a81..79f47633d 100644 --- a/stormpy/src/storage/matrix.cpp +++ b/stormpy/src/storage/matrix.cpp @@ -19,7 +19,7 @@ void define_sparse_matrix(py::module& m) { ; py::class_>(m, "SparseMatrix", "Sparse matrix") - .def("__iter__", [](storm::storage::SparseMatrix const& matrix) { + .def("__iter__", [](storm::storage::SparseMatrix& matrix) { return py::make_iterator(matrix.begin(), matrix.end()); }, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) .def("__str__", [](storm::storage::SparseMatrix const& matrix) { @@ -37,7 +37,7 @@ void define_sparse_matrix(py::module& m) { .def("get_rows", [](storm::storage::SparseMatrix& matrix, entry_index start, entry_index end) { return matrix.getRows(start, end); }, "Get rows from start to end") - .def("print_row", [](storm::storage::SparseMatrix& matrix, entry_index row) { + .def("print_row", [](storm::storage::SparseMatrix const& matrix, entry_index row) { std::stringstream stream; auto rows = matrix.getRows(row, row+1); for (auto transition : rows) { diff --git a/stormpy/src/storage/model.cpp b/stormpy/src/storage/model.cpp index 64c247805..04b1eed78 100644 --- a/stormpy/src/storage/model.cpp +++ b/stormpy/src/storage/model.cpp @@ -17,7 +17,7 @@ std::vector getInitialStates(storm::models:: } // Thin wrapper for getting transition matrix -storm::storage::SparseMatrix getTransitionMatrix(storm::models::sparse::Model const& model) { +storm::storage::SparseMatrix& getTransitionMatrix(storm::models::sparse::Model& model) { return model.getTransitionMatrix(); } @@ -53,7 +53,7 @@ void define_model(py::module& m) { }, "Get labels") .def("labels_state", &storm::models::sparse::Model::getLabelsOfState, "Get labels") .def("initial_states", &getInitialStates, "Get initial states") - .def("transition_matrix", &getTransitionMatrix, "Get transition matrix") + .def("transition_matrix", &getTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Get transition matrix") ; py::class_, std::shared_ptr>>(m, "SparseDtmc", "DTMC in sparse representation", py::base>()) ; diff --git a/stormpy/tests/storage/test_matrix.py b/stormpy/tests/storage/test_matrix.py index cde736f03..23beb8ece 100644 --- a/stormpy/tests/storage/test_matrix.py +++ b/stormpy/tests/storage/test_matrix.py @@ -24,3 +24,42 @@ class TestMatrix: for e in matrix: assert e.val() == i i += 0.1 + + def test_change_sparse_matrix_modelchecking(self): + import stormpy.logic + model = stormpy.parse_explicit_model("../examples/dtmc/die/die.tra", "../examples/dtmc/die/die.lab") + matrix = model.transition_matrix() + # Check matrix + for e in matrix: + assert e.val() == 0.5 or e.val() == 0 or e.val() == 1 + # First model checking + formulas = stormpy.parse_formulas("P=? [ F \"one\" ]") + result = stormpy.model_checking(model, formulas[0]) + assert result == 0.16666666666666663 + + # Change probabilities + i = 0 + for e in matrix: + if e.val() == 0.5: + if i % 2 == 0: + e.set_val(0.3) + else: + e.set_val(0.7) + i += 1 + for e in matrix: + assert e.val() == 0.3 or e.val() == 0.7 or e.val() == 1 or e.val() == 0 + # Second model checking + result = stormpy.model_checking(model, formulas[0]) + assert result == 0.06923076923076932 + + # Change probabilities again + for state in stormpy.state.State(0, model): + for action in state.actions(): + for transition in action.transitions(): + if transition.val() == 0.3: + transition.set_val(0.8) + elif transition.val() == 0.7: + transition.set_val(0.2) + # Third model checking + result = stormpy.model_checking(model, formulas[0]) + assert result == 0.3555555555555556 or result == 0.3555555555555557