diff --git a/tests/storage/test_matrix_builder.py b/tests/storage/test_matrix_builder.py index 769876a..ca0eaec 100644 --- a/tests/storage/test_matrix_builder.py +++ b/tests/storage/test_matrix_builder.py @@ -1,8 +1,5 @@ import stormpy import numpy as np -from helpers.helper import get_example_path - -import math class TestMatrixBuilder: @@ -16,18 +13,16 @@ class TestMatrixBuilder: builder_5x5 = stormpy.SparseMatrixBuilder(5, 5, force_dimensions=False) builder_5x5.add_next_value(0, 0, 0) - builder_5x5.add_next_value(0, 2, 1) - - builder_5x5.add_next_value(2, 0, 4) - builder_5x5.add_next_value(2, 3, 5) + builder_5x5.add_next_value(0, 1, 0.1) + builder_5x5.add_next_value(2, 2, 22) + builder_5x5.add_next_value(2, 3, 23) assert builder_5x5.get_last_column() == 3 assert builder_5x5.get_last_row() == 2 - builder_5x5.add_next_value(3, 1, 0.5) - builder_5x5.add_next_value(3, 3, 0) - - builder_5x5.add_next_value(4, 4, 0.2) + builder_5x5.add_next_value(3, 2, 32) + builder_5x5.add_next_value(3, 4, 34) + builder_5x5.add_next_value(4, 3, 43) matrix_5x5 = builder_5x5.build() @@ -35,8 +30,35 @@ class TestMatrixBuilder: assert matrix_5x5.nr_rows == 5 assert matrix_5x5.nr_entries == 7 - # todo test Replace columns - # builder_5x5.replace_columns... + for e in matrix_5x5: + assert (e.value() == 0.1 and e.column == 1) or e.value() == 0 or (e.value() > 20 and e.column > 1) + + def test_matrix_replace_columns(self): + builder = stormpy.SparseMatrixBuilder(3, 4, force_dimensions=False) + + builder.add_next_value(0, 0, 0) + builder.add_next_value(0, 1, 1) + builder.add_next_value(0, 2, 2) + builder.add_next_value(0, 3, 3) + + builder.add_next_value(1, 1, 1) + builder.add_next_value(1, 2, 2) + builder.add_next_value(1, 3, 3) + + builder.add_next_value(2, 1, 1) + builder.add_next_value(2, 2, 2) + builder.add_next_value(2, 3, 3) + + # replace rows + builder.replace_columns([3, 2, 1], 1) + matrix = builder.build() + + assert matrix.nr_entries == 10 + + # Check if columns where replaced + for e in matrix: + assert (e.value() == 0 and e.column == 0) or (e.value() == 3 and e.column == 1) or ( + e.value() == 2 and e.column == 2) or (e.value() == 1 and e.column == 3) def test_matrix_builder_row_grouping(self):