Browse Source

Added binding for SparseMatrix::getSubmatrix

refactoring
Matthias Volk 7 years ago
parent
commit
ef38b73227
  1. 37
      src/storage/matrix.cpp
  2. 27
      tests/storage/test_matrix.py

37
src/storage/matrix.cpp

@ -1,5 +1,6 @@
#include "matrix.h" #include "matrix.h"
#include "storm/storage/SparseMatrix.h" #include "storm/storage/SparseMatrix.h"
#include "storm/storage/BitVector.h"
#include "src/helpers.h" #include "src/helpers.h"
template<typename ValueType> using SparseMatrix = storm::storage::SparseMatrix<ValueType>; template<typename ValueType> using SparseMatrix = storm::storage::SparseMatrix<ValueType>;
@ -52,6 +53,9 @@ void define_sparse_matrix(py::module& m) {
} }
return stream.str(); return stream.str();
}, py::arg("row"), "Print rows from start to end") }, py::arg("row"), "Print rows from start to end")
.def("submatrix", [](SparseMatrix<double> const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) {
return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries);
}, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix")
// Entry_index lead to problems // Entry_index lead to problems
.def("row_iter", [](SparseMatrix<double>& matrix, row_index start, row_index end) { .def("row_iter", [](SparseMatrix<double>& matrix, row_index start, row_index end) {
return py::make_iterator(matrix.begin(start), matrix.end(end)); return py::make_iterator(matrix.begin(start), matrix.end(end));
@ -60,10 +64,10 @@ void define_sparse_matrix(py::module& m) {
// (partial) container interface to allow e.g. matrix[7:9] // (partial) container interface to allow e.g. matrix[7:9]
.def("__len__", &SparseMatrix<double>::getRowCount) .def("__len__", &SparseMatrix<double>::getRowCount)
.def("__getitem__", [](SparseMatrix<double>& matrix, entry_index<double> i) { .def("__getitem__", [](SparseMatrix<double>& matrix, entry_index<double> i) {
if (i >= matrix.getRowCount())
throw py::index_error();
return matrix.getRows(i, i+1);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
if (i >= matrix.getRowCount())
throw py::index_error();
return matrix.getRows(i, i+1);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
.def("__getitem__", [](SparseMatrix<double>& matrix, py::slice slice) { .def("__getitem__", [](SparseMatrix<double>& matrix, py::slice slice) {
size_t start, stop, step, slice_length; size_t start, stop, step, slice_length;
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length)) if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length))
@ -98,6 +102,9 @@ void define_sparse_matrix(py::module& m) {
} }
return stream.str(); return stream.str();
}, py::arg("row"), "Print row") }, py::arg("row"), "Print row")
.def("submatrix", [](SparseMatrix<RationalFunction> const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) {
return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries);
}, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix")
// Entry_index lead to problems // Entry_index lead to problems
.def("row_iter", [](SparseMatrix<RationalFunction>& matrix, row_index start, row_index end) { .def("row_iter", [](SparseMatrix<RationalFunction>& matrix, row_index start, row_index end) {
return py::make_iterator(matrix.begin(start), matrix.end(end)); return py::make_iterator(matrix.begin(start), matrix.end(end));
@ -106,18 +113,18 @@ void define_sparse_matrix(py::module& m) {
// (partial) container interface to allow e.g. matrix[7:9] // (partial) container interface to allow e.g. matrix[7:9]
.def("__len__", &SparseMatrix<RationalFunction>::getRowCount) .def("__len__", &SparseMatrix<RationalFunction>::getRowCount)
.def("__getitem__", [](SparseMatrix<RationalFunction>& matrix, entry_index<RationalFunction> i) { .def("__getitem__", [](SparseMatrix<RationalFunction>& matrix, entry_index<RationalFunction> i) {
if (i >= matrix.getRowCount())
throw py::index_error();
return matrix.getRows(i, i+1);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
if (i >= matrix.getRowCount())
throw py::index_error();
return matrix.getRows(i, i+1);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
.def("__getitem__", [](SparseMatrix<RationalFunction>& matrix, py::slice slice) { .def("__getitem__", [](SparseMatrix<RationalFunction>& matrix, py::slice slice) {
size_t start, stop, step, slice_length;
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length))
throw py::error_already_set();
if (step != 1)
throw py::value_error(); // not supported
return matrix.getRows(start, stop);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
size_t start, stop, step, slice_length;
if (!slice.compute(matrix.getRowCount(), &start, &stop, &step, &slice_length))
throw py::error_already_set();
if (step != 1)
throw py::value_error(); // not supported
return matrix.getRows(start, stop);
}, py::return_value_policy::reference, py::keep_alive<1, 0>())
; ;
// Rows // Rows

27
tests/storage/test_matrix.py

@ -5,7 +5,7 @@ import math
class TestMatrix: class TestMatrix:
def test_sparse_matrix(self):
def test_matrix(self):
model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"),
get_example_path("dtmc", "die.lab")) get_example_path("dtmc", "die.lab"))
matrix = model.transition_matrix matrix = model.transition_matrix
@ -29,7 +29,7 @@ class TestMatrix:
for e in matrix: for e in matrix:
assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6) assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6)
def test_change_sparse_matrix(self):
def test_change_matrix(self):
model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"),
get_example_path("dtmc", "die.lab")) get_example_path("dtmc", "die.lab"))
matrix = model.transition_matrix matrix = model.transition_matrix
@ -44,7 +44,7 @@ class TestMatrix:
assert e.value() == i assert e.value() == i
i += 0.1 i += 0.1
def test_change_sparse_matrix_modelchecking(self):
def test_change_matrix_modelchecking(self):
import stormpy.logic import stormpy.logic
model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"), model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"),
get_example_path("dtmc", "die.lab")) get_example_path("dtmc", "die.lab"))
@ -87,7 +87,7 @@ class TestMatrix:
resValue = result.at(model.initial_states[0]) resValue = result.at(model.initial_states[0])
assert math.isclose(resValue, 0.3555555555555556) assert math.isclose(resValue, 0.3555555555555556)
def test_change_parametric_sparse_matrix_modelchecking(self):
def test_change_parametric_matrix_modelchecking(self):
import stormpy.logic import stormpy.logic
program = stormpy.parse_prism_program(get_example_path("pdtmc", "brp16_2.pm")) program = stormpy.parse_prism_program(get_example_path("pdtmc", "brp16_2.pm"))
@ -120,3 +120,22 @@ class TestMatrix:
result = stormpy.model_checking(model, formulas[0]) result = stormpy.model_checking(model, formulas[0])
ratFunc = result.at(initial_state) ratFunc = result.at(initial_state)
assert len(ratFunc.gather_variables()) == 0 assert len(ratFunc.gather_variables()) == 0
def test_submatrix(self):
model = stormpy.build_sparse_model_from_explicit(get_example_path("dtmc", "die.tra"),
get_example_path("dtmc", "die.lab"))
matrix = model.transition_matrix
assert matrix.nr_rows == 13
assert matrix.nr_columns == 13
assert matrix.nr_entries == 20
assert matrix.nr_entries == model.nr_transitions
for e in matrix:
assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6)
row_constraint = stormpy.BitVector(13, [0, 1, 3, 4, 7, 8, 9])
submatrix = matrix.submatrix(row_constraint, row_constraint)
assert submatrix.nr_rows == 7
assert submatrix.nr_columns == 7
assert submatrix.nr_entries == 10
for e in submatrix:
assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 3)
Loading…
Cancel
Save