Browse Source

add Matrix slicing

refactoring
Tom Janson 8 years ago
parent
commit
a97de7a078
  1. 32
      src/storage/matrix.cpp

32
src/storage/matrix.cpp

@ -54,6 +54,22 @@ void define_sparse_matrix(py::module& m) {
.def("row_iter", [](storm::storage::SparseMatrix<double>& matrix, row_index start, row_index end) {
return py::make_iterator(matrix.begin(start), matrix.end(end));
}, py::keep_alive<0, 1>() /* keep object alive while iterator exists */, py::arg("row_start"), py::arg("row_end"), "Get iterator from start to end")
// (partial) container interface to allow e.g. matrix[7:9]
.def("__len__", &storm::storage::SparseMatrix<double>::getRowCount)
.def("__getitem__", [](storm::storage::SparseMatrix<double>& matrix, entry_index i) {
if (i >= matrix.getRowCount())
throw py::index_error();
return matrix.getRows(i, i+1);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
.def("__getitem__", [](storm::storage::SparseMatrix<double>& matrix, py::slice slice) {
size_t start, stop, step, slice_length;
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length))
throw py::error_already_set();
if (step != 1)
throw py::value_error(); // not supported
return matrix.getRows(start, stop);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
;
py::class_<storm::storage::SparseMatrix<storm::RationalFunction>>(m, "ParametricSparseMatrix", "Parametric sparse matrix")
@ -83,6 +99,22 @@ void define_sparse_matrix(py::module& m) {
.def("row_iter", [](storm::storage::SparseMatrix<storm::RationalFunction>& matrix, parametric_row_index start, parametric_row_index end) {
return py::make_iterator(matrix.begin(start), matrix.end(end));
}, py::keep_alive<0, 1>() /* keep object alive while iterator exists */, py::arg("row_start"), py::arg("row_end"), "Get iterator from start to end")
// (partial) container interface to allow e.g. matrix[7:9]
.def("__len__", &storm::storage::SparseMatrix<storm::RationalFunction>::getRowCount)
.def("__getitem__", [](storm::storage::SparseMatrix<storm::RationalFunction>& matrix, parametric_entry_index i) {
if (i >= matrix.getRowCount())
throw py::index_error();
return matrix.getRows(i, i+1);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
.def("__getitem__", [](storm::storage::SparseMatrix<storm::RationalFunction>& matrix, py::slice slice) {
size_t start, stop, step, slice_length;
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length))
throw py::error_already_set();
if (step != 1)
throw py::value_error(); // not supported
return matrix.getRows(start, stop);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
;
// Rows

Loading…
Cancel
Save