|
|
@ -1,5 +1,6 @@ |
|
|
|
#include "matrix.h"
|
|
|
|
#include "storm/storage/SparseMatrix.h"
|
|
|
|
#include "storm/storage/BitVector.h"
|
|
|
|
#include "src/helpers.h"
|
|
|
|
|
|
|
|
template<typename ValueType> using SparseMatrix = storm::storage::SparseMatrix<ValueType>; |
|
|
@ -52,6 +53,9 @@ void define_sparse_matrix(py::module& m) { |
|
|
|
} |
|
|
|
return stream.str(); |
|
|
|
}, py::arg("row"), "Print rows from start to end") |
|
|
|
.def("submatrix", [](SparseMatrix<double> const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) { |
|
|
|
return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries); |
|
|
|
}, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix") |
|
|
|
// Entry_index lead to problems
|
|
|
|
.def("row_iter", [](SparseMatrix<double>& matrix, row_index start, row_index end) { |
|
|
|
return py::make_iterator(matrix.begin(start), matrix.end(end)); |
|
|
@ -60,10 +64,10 @@ void define_sparse_matrix(py::module& m) { |
|
|
|
// (partial) container interface to allow e.g. matrix[7:9]
|
|
|
|
.def("__len__", &SparseMatrix<double>::getRowCount) |
|
|
|
.def("__getitem__", [](SparseMatrix<double>& matrix, entry_index<double> i) { |
|
|
|
if (i >= matrix.getRowCount()) |
|
|
|
throw py::index_error(); |
|
|
|
return matrix.getRows(i, i+1); |
|
|
|
}, py::return_value_policy::reference, py::keep_alive<1, 0>()) |
|
|
|
if (i >= matrix.getRowCount()) |
|
|
|
throw py::index_error(); |
|
|
|
return matrix.getRows(i, i+1); |
|
|
|
}, py::return_value_policy::reference, py::keep_alive<1, 0>()) |
|
|
|
.def("__getitem__", [](SparseMatrix<double>& matrix, py::slice slice) { |
|
|
|
size_t start, stop, step, slice_length; |
|
|
|
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) |
|
|
@ -98,6 +102,9 @@ void define_sparse_matrix(py::module& m) { |
|
|
|
} |
|
|
|
return stream.str(); |
|
|
|
}, py::arg("row"), "Print row") |
|
|
|
.def("submatrix", [](SparseMatrix<RationalFunction> const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) { |
|
|
|
return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries); |
|
|
|
}, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix") |
|
|
|
// Entry_index lead to problems
|
|
|
|
.def("row_iter", [](SparseMatrix<RationalFunction>& matrix, row_index start, row_index end) { |
|
|
|
return py::make_iterator(matrix.begin(start), matrix.end(end)); |
|
|
@ -106,18 +113,18 @@ void define_sparse_matrix(py::module& m) { |
|
|
|
// (partial) container interface to allow e.g. matrix[7:9]
|
|
|
|
.def("__len__", &SparseMatrix<RationalFunction>::getRowCount) |
|
|
|
.def("__getitem__", [](SparseMatrix<RationalFunction>& matrix, entry_index<RationalFunction> i) { |
|
|
|
if (i >= matrix.getRowCount()) |
|
|
|
throw py::index_error(); |
|
|
|
return matrix.getRows(i, i+1); |
|
|
|
}, py::return_value_policy::reference, py::keep_alive<1, 0>()) |
|
|
|
if (i >= matrix.getRowCount()) |
|
|
|
throw py::index_error(); |
|
|
|
return matrix.getRows(i, i+1); |
|
|
|
}, py::return_value_policy::reference, py::keep_alive<1, 0>()) |
|
|
|
.def("__getitem__", [](SparseMatrix<RationalFunction>& matrix, py::slice slice) { |
|
|
|
size_t start, stop, step, slice_length; |
|
|
|
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) |
|
|
|
throw py::error_already_set(); |
|
|
|
if (step != 1) |
|
|
|
throw py::value_error(); // not supported
|
|
|
|
return matrix.getRows(start, stop); |
|
|
|
}, py::return_value_policy::reference, py::keep_alive<1, 0>()) |
|
|
|
size_t start, stop, step, slice_length; |
|
|
|
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) |
|
|
|
throw py::error_already_set(); |
|
|
|
if (step != 1) |
|
|
|
throw py::value_error(); // not supported
|
|
|
|
return matrix.getRows(start, stop); |
|
|
|
}, py::return_value_policy::reference, py::keep_alive<1, 0>()) |
|
|
|
; |
|
|
|
|
|
|
|
// Rows
|
|
|
|