You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.1 KiB
102 lines
3.1 KiB
import stormpy
|
|
import numpy as np
|
|
from helpers.helper import get_example_path
|
|
|
|
import math
|
|
|
|
|
|
class TestMatrixBuilder:
    """Tests for ``stormpy.SparseMatrixBuilder`` and ``stormpy.build_sparse_matrix``."""

    @staticmethod
    def _example_array():
        """Return the 4x2 float64 array shared by the numpy-construction tests."""
        return np.array([[0, 2],
                         [3, 4],
                         [0.1, 24],
                         [-0.3, -4]], dtype='float64')

    @staticmethod
    def _assert_matches_array(matrix, array):
        """Assert every stored entry of ``matrix`` equals the matching cell of ``array``.

        All ``array.shape[0]`` rows are visited. The number of entries seen is
        compared against ``matrix.nr_entries`` so that no stored entry is
        silently skipped.
        """
        visited = 0
        for r in range(array.shape[0]):
            for entry in matrix.get_row(r):
                assert entry.value() == array[r, entry.column]
                visited += 1
        assert visited == matrix.nr_entries

    def test_matrix_builder(self):
        """Build an empty matrix, then a 5x5 matrix entry by entry."""
        # With force_dimensions=True and no entries added, build() yields a 0x0 matrix.
        builder = stormpy.SparseMatrixBuilder(force_dimensions=True)
        matrix = builder.build()
        assert matrix.nr_columns == 0
        assert matrix.nr_rows == 0
        assert matrix.nr_entries == 0

        builder_5x5 = stormpy.SparseMatrixBuilder(5, 5, force_dimensions=False)

        builder_5x5.add_next_value(0, 0, 0)
        builder_5x5.add_next_value(0, 2, 1)

        builder_5x5.add_next_value(2, 0, 4)
        builder_5x5.add_next_value(2, 3, 5)

        # The builder reports the row/column of the most recently added entry, (2, 3).
        assert builder_5x5.get_last_column() == 3
        assert builder_5x5.get_last_row() == 2

        builder_5x5.add_next_value(3, 1, 0.5)
        builder_5x5.add_next_value(3, 3, 0)

        builder_5x5.add_next_value(4, 4, 0.2)

        matrix_5x5 = builder_5x5.build()

        assert matrix_5x5.nr_columns == 5
        assert matrix_5x5.nr_rows == 5
        # The explicitly added zeros at (0, 0) and (3, 3) are counted as stored entries.
        assert matrix_5x5.nr_entries == 7

        # TODO: test replace_columns, e.g. builder_5x5.replace_columns(...)

    def test_matrix_builder_row_grouping(self):
        """Declare row groups explicitly on a builder with custom row grouping."""
        num_rows = 5
        builder = stormpy.SparseMatrixBuilder(num_rows, 6, has_custom_row_grouping=True, row_groups=2)

        # new_row_group(start) opens a new group beginning at row index `start`.
        builder.new_row_group(1)
        assert builder.get_current_row_group_count() == 1

        builder.new_row_group(4)
        assert builder.get_current_row_group_count() == 2

        matrix = builder.build()

        # Group 0 covers rows [1, 4); the last group runs up to the total row count.
        assert matrix.get_row_group_start(0) == 1
        assert matrix.get_row_group_end(0) == 4

        assert matrix.get_row_group_start(1) == 4
        assert matrix.get_row_group_end(1) == num_rows

    def test_matrix_from_numpy(self):
        """Build a sparse matrix from a dense numpy array and compare every value."""
        array = self._example_array()

        matrix = stormpy.build_sparse_matrix(array)

        # Check matrix dimension; every cell of the dense array (zeros included) is stored.
        assert matrix.nr_rows == array.shape[0]
        assert matrix.nr_columns == array.shape[1]
        assert matrix.nr_entries == array.size

        # Check matrix values in all rows (not just the first array.shape[1] rows).
        self._assert_matches_array(matrix, array)

    def test_matrix_from_numpy_row_grouping(self):
        """Build a row-grouped sparse matrix from a dense numpy array."""
        array = self._example_array()

        matrix = stormpy.build_sparse_matrix(array, row_group_indices=[1, 3])

        # Check matrix dimension; every cell of the dense array (zeros included) is stored.
        assert matrix.nr_rows == array.shape[0]
        assert matrix.nr_columns == array.shape[1]
        assert matrix.nr_entries == array.size

        # Check matrix values in all rows (not just the first array.shape[1] rows).
        self._assert_matches_array(matrix, array)

        # Check row groups: the given indices are the group starts; the last group
        # ends at the total row count.
        assert matrix.get_row_group_start(0) == 1
        assert matrix.get_row_group_end(0) == 3

        assert matrix.get_row_group_start(1) == 3
        assert matrix.get_row_group_end(1) == array.shape[0]