hannah
5 years ago
committed by
Matthias Volk
3 changed files with 130 additions and 3 deletions
-
27lib/stormpy/storage/__init__.py
-
4src/storage/matrix.cpp
-
102tests/storage/test_matrix_builder.py
@ -1,4 +1,29 @@ |
|||||
|
|
||||
import stormpy.utility |
import stormpy.utility |
||||
from . import storage |
from . import storage |
||||
from .storage import * |
from .storage import * |
||||
|
|
||||
|
|
||||
|
def build_sparse_matrix(array, row_group_indices=[]): |
||||
|
""" |
||||
|
Build a sparse matrix from numpy array. |
||||
|
|
||||
|
:param numpy array: The array. |
||||
|
:param List[double] row_group_indices: List containing the starting row of each row group. |
||||
|
:return: Sparse matrix. |
||||
|
""" |
||||
|
num_row = array.shape[0] |
||||
|
num_col = array.shape[1] |
||||
|
|
||||
|
if 0 < len(row_group_indices): |
||||
|
builder = storage.SparseMatrixBuilder(rows=num_row, columns=num_col, has_custom_row_grouping=True, |
||||
|
row_groups=len(row_group_indices)) |
||||
|
else: |
||||
|
builder = storage.SparseMatrixBuilder(rows=num_row, columns=num_col) |
||||
|
|
||||
|
for r in range(num_row): |
||||
|
if r in row_group_indices: |
||||
|
builder.new_row_group(r) |
||||
|
for c in range(num_col): |
||||
|
builder.add_next_value(r, c, array[r, c]) |
||||
|
|
||||
|
return builder.build() |
@ -0,0 +1,102 @@ |
|||||
|
import stormpy |
||||
|
import numpy as np |
||||
|
from helpers.helper import get_example_path |
||||
|
|
||||
|
import math |
||||
|
|
||||
|
|
||||
|
class TestMatrixBuilder: |
||||
|
def test_matrix_builder(self): |
||||
|
builder = stormpy.SparseMatrixBuilder(force_dimensions=True) |
||||
|
matrix = builder.build() |
||||
|
assert matrix.nr_columns == 0 |
||||
|
assert matrix.nr_rows == 0 |
||||
|
assert matrix.nr_entries == 0 |
||||
|
|
||||
|
builder_5x5 = stormpy.SparseMatrixBuilder(5, 5, force_dimensions=False) |
||||
|
|
||||
|
builder_5x5.add_next_value(0, 0, 0) |
||||
|
builder_5x5.add_next_value(0, 2, 1) |
||||
|
|
||||
|
builder_5x5.add_next_value(2, 0, 4) |
||||
|
builder_5x5.add_next_value(2, 3, 5) |
||||
|
|
||||
|
assert builder_5x5.get_last_column() == 3 |
||||
|
assert builder_5x5.get_last_row() == 2 |
||||
|
|
||||
|
builder_5x5.add_next_value(3, 1, 0.5) |
||||
|
builder_5x5.add_next_value(3, 3, 0) |
||||
|
|
||||
|
builder_5x5.add_next_value(4, 4, 0.2) |
||||
|
|
||||
|
matrix_5x5 = builder_5x5.build() |
||||
|
|
||||
|
assert matrix_5x5.nr_columns == 5 |
||||
|
assert matrix_5x5.nr_rows == 5 |
||||
|
assert matrix_5x5.nr_entries == 7 |
||||
|
|
||||
|
# todo test Replace columns |
||||
|
# builder_5x5.replace_columns... |
||||
|
|
||||
|
def test_matrix_builder_row_grouping(self): |
||||
|
|
||||
|
num_rows = 5 |
||||
|
builder = stormpy.SparseMatrixBuilder(num_rows, 6, has_custom_row_grouping=True, row_groups=2) |
||||
|
|
||||
|
builder.new_row_group(1) |
||||
|
assert builder.get_current_row_group_count() == 1 |
||||
|
|
||||
|
builder.new_row_group(4) |
||||
|
assert builder.get_current_row_group_count() == 2 |
||||
|
matrix = builder.build() |
||||
|
|
||||
|
assert matrix.get_row_group_start(0) == 1 |
||||
|
assert matrix.get_row_group_end(0) == 4 |
||||
|
|
||||
|
assert matrix.get_row_group_start(1) == 4 |
||||
|
assert matrix.get_row_group_end(1) == 5 |
||||
|
|
||||
|
def test_matrix_from_numpy(self): |
||||
|
array = np.array([[0, 2], |
||||
|
[3, 4], |
||||
|
[0.1, 24], |
||||
|
[-0.3, -4]], dtype='float64') |
||||
|
|
||||
|
matrix = stormpy.build_sparse_matrix(array) |
||||
|
|
||||
|
# Check matrix dimension |
||||
|
assert matrix.nr_rows == array.shape[0] |
||||
|
assert matrix.nr_columns == array.shape[1] |
||||
|
assert matrix.nr_entries == 8 |
||||
|
|
||||
|
# Check matrix values |
||||
|
for r in range(array.shape[1]): |
||||
|
row = matrix.get_row(r) |
||||
|
for e in row: |
||||
|
assert (e.value() == array[r, e.column]) |
||||
|
|
||||
|
def test_matrix_from_numpy_row_grouping(self): |
||||
|
array = np.array([[0, 2], |
||||
|
[3, 4], |
||||
|
[0.1, 24], |
||||
|
[-0.3, -4]], dtype='float64') |
||||
|
|
||||
|
matrix = stormpy.build_sparse_matrix(array, row_group_indices=[1, 3]) |
||||
|
|
||||
|
# Check matrix dimension |
||||
|
assert matrix.nr_rows == array.shape[0] |
||||
|
assert matrix.nr_columns == array.shape[1] |
||||
|
assert matrix.nr_entries == 8 |
||||
|
|
||||
|
# Check matrix values |
||||
|
for r in range(array.shape[1]): |
||||
|
row = matrix.get_row(r) |
||||
|
for e in row: |
||||
|
assert (e.value() == array[r, e.column]) |
||||
|
|
||||
|
# Check row groups |
||||
|
assert matrix.get_row_group_start(0) == 1 |
||||
|
assert matrix.get_row_group_end(0) == 3 |
||||
|
|
||||
|
assert matrix.get_row_group_start(1) == 3 |
||||
|
assert matrix.get_row_group_end(1) == 4 |
Write
Preview
Loading…
Cancel
Save
Reference in new issue