From ef38b73227e8c8c520a07732a30a635a61e76d1a Mon Sep 17 00:00:00 2001 From: Matthias Volk Date: Wed, 15 Nov 2017 17:27:44 +0100 Subject: [PATCH] Added binding for SparseMatrix::getSubmatrix --- src/storage/matrix.cpp | 37 +++++++++++++++++++++--------------- tests/storage/test_matrix.py | 27 ++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/src/storage/matrix.cpp b/src/storage/matrix.cpp index 9b6e3e0..0e09ffc 100644 --- a/src/storage/matrix.cpp +++ b/src/storage/matrix.cpp @@ -1,5 +1,6 @@ #include "matrix.h" #include "storm/storage/SparseMatrix.h" +#include "storm/storage/BitVector.h" #include "src/helpers.h" template using SparseMatrix = storm::storage::SparseMatrix; @@ -52,6 +53,9 @@ void define_sparse_matrix(py::module& m) { } return stream.str(); }, py::arg("row"), "Print rows from start to end") + .def("submatrix", [](SparseMatrix const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) { + return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries); + }, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix") // Entry_index lead to problems .def("row_iter", [](SparseMatrix& matrix, row_index start, row_index end) { return py::make_iterator(matrix.begin(start), matrix.end(end)); @@ -60,10 +64,10 @@ void define_sparse_matrix(py::module& m) { // (partial) container interface to allow e.g. matrix[7:9] .def("__len__", &SparseMatrix::getRowCount) .def("__getitem__", [](SparseMatrix& matrix, entry_index i) { - if (i >= matrix.getRowCount()) - throw py::index_error(); - return matrix.getRows(i, i+1); - }, py::return_value_policy::reference, py::keep_alive<1, 0>()) + if (i >= matrix.getRowCount()) + throw py::index_error(); + return matrix.getRows(i, i+1); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) .def("__getitem__", [](SparseMatrix& matrix, py::slice slice) { size_t start, stop, step, slice_length; if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) @@ -98,6 +102,9 @@ void define_sparse_matrix(py::module& m) { } return stream.str(); }, py::arg("row"), "Print row") + .def("submatrix", [](SparseMatrix const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) { + return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries); + }, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix") // Entry_index lead to problems .def("row_iter", [](SparseMatrix& matrix, row_index start, row_index end) { return py::make_iterator(matrix.begin(start), matrix.end(end)); @@ -106,18 +113,18 @@ void define_sparse_matrix(py::module& m) { // (partial) container interface to allow e.g. matrix[7:9] .def("__len__", &SparseMatrix::getRowCount) .def("__getitem__", [](SparseMatrix& matrix, entry_index i) { - if (i >= matrix.getRowCount()) - throw py::index_error(); - return matrix.getRows(i, i+1); - }, py::return_value_policy::reference, py::keep_alive<1, 0>()) + if (i >= matrix.getRowCount()) + throw py::index_error(); + return matrix.getRows(i, i+1); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) .def("__getitem__", [](SparseMatrix& matrix, py::slice slice) { - size_t start, stop, step, slice_length; - if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) - throw py::error_already_set(); - if (step != 1) - throw py::value_error(); // not supported - return matrix.getRows(start, stop); - }, py::return_value_policy::reference, py::keep_alive<1, 0>()) + size_t start, stop, step, slice_length; + if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) + throw py::error_already_set(); + if (step != 1) + throw py::value_error(); // not supported + return matrix.getRows(start, stop); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) ; // Rows diff --git a/tests/storage/test_matrix.py b/tests/storage/test_matrix.py index 3a29e8a..8750483 100644 --- a/tests/storage/test_matrix.py +++ b/tests/storage/test_matrix.py @@ -5,7 +5,7 @@ import math class TestMatrix: - def test_sparse_matrix(self): + def test_matrix(self): model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) matrix = model.transition_matrix @@ -29,7 +29,7 @@ class TestMatrix: for e in matrix: assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6) - def test_change_sparse_matrix(self): + def test_change_matrix(self): model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) matrix = model.transition_matrix @@ -44,7 +44,7 @@ class TestMatrix: assert e.value() == i i += 0.1 - def test_change_sparse_matrix_modelchecking(self): + def test_change_matrix_modelchecking(self): import stormpy.logic model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), get_example_path("dtmc", "die.lab")) @@ -87,7 +87,7 @@ class TestMatrix: resValue = result.at(model.initial_states[0]) assert math.isclose(resValue, 0.3555555555555556) - def test_change_parametric_sparse_matrix_modelchecking(self): + def test_change_parametric_matrix_modelchecking(self): import stormpy.logic program = stormpy.parse_prism_program(get_example_path("pdtmc", "brp16_2.pm")) @@ -120,3 +120,22 @@ class TestMatrix: result = stormpy.model_checking(model, formulas[0]) ratFunc = result.at(initial_state) assert len(ratFunc.gather_variables()) == 0 + + def test_submatrix(self): + model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), + get_example_path("dtmc", "die.lab")) + matrix = model.transition_matrix + assert matrix.nr_rows == 13 + assert matrix.nr_columns == 13 + assert matrix.nr_entries == 20 + assert matrix.nr_entries == model.nr_transitions + for e in matrix: + assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6) + + row_constraint = stormpy.BitVector(13, [0, 1, 3, 4, 7, 8, 9]) + submatrix = matrix.submatrix(row_constraint, row_constraint) + assert submatrix.nr_rows == 7 + assert submatrix.nr_columns == 7 + assert submatrix.nr_entries == 10 + for e in submatrix: + assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 3)