From a97de7a0786771752d4881e2373846d7ee4a914e Mon Sep 17 00:00:00 2001 From: Tom Janson Date: Fri, 17 Mar 2017 17:38:23 +0100 Subject: [PATCH] add Matrix slicing --- src/storage/matrix.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/storage/matrix.cpp b/src/storage/matrix.cpp index 9ba8504..39149f8 100644 --- a/src/storage/matrix.cpp +++ b/src/storage/matrix.cpp @@ -54,6 +54,22 @@ void define_sparse_matrix(py::module& m) { .def("row_iter", [](storm::storage::SparseMatrix& matrix, row_index start, row_index end) { return py::make_iterator(matrix.begin(start), matrix.end(end)); }, py::keep_alive<0, 1>() /* keep object alive while iterator exists */, py::arg("row_start"), py::arg("row_end"), "Get iterator from start to end") + + // (partial) container interface to allow e.g. matrix[7:9] + .def("__len__", &storm::storage::SparseMatrix::getRowCount) + .def("__getitem__", [](storm::storage::SparseMatrix& matrix, entry_index i) { + if (i >= matrix.getRowCount()) + throw py::index_error(); + return matrix.getRows(i, i+1); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) + .def("__getitem__", [](storm::storage::SparseMatrix& matrix, py::slice slice) { + size_t start, stop, step, slice_length; + if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) + throw py::error_already_set(); + if (step != 1) + throw py::value_error(); // not supported + return matrix.getRows(start, stop); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) ; py::class_>(m, "ParametricSparseMatrix", "Parametric sparse matrix") @@ -83,6 +99,22 @@ void define_sparse_matrix(py::module& m) { .def("row_iter", [](storm::storage::SparseMatrix& matrix, parametric_row_index start, parametric_row_index end) { return py::make_iterator(matrix.begin(start), matrix.end(end)); }, py::keep_alive<0, 1>() /* keep object alive while iterator exists */, py::arg("row_start"), py::arg("row_end"), "Get iterator from start to end") + + // (partial) container interface to allow e.g. matrix[7:9] + .def("__len__", &storm::storage::SparseMatrix::getRowCount) + .def("__getitem__", [](storm::storage::SparseMatrix& matrix, parametric_entry_index i) { + if (i >= matrix.getRowCount()) + throw py::index_error(); + return matrix.getRows(i, i+1); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) + .def("__getitem__", [](storm::storage::SparseMatrix& matrix, py::slice slice) { + size_t start, stop, step, slice_length; + if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) + throw py::error_already_set(); + if (step != 1) + throw py::value_error(); // not supported + return matrix.getRows(start, stop); + }, py::return_value_policy::reference, py::keep_alive<1, 0>()) ; // Rows