The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

282 lines
9.9 KiB

4 months ago
  1. import stormpy
  2. from configurations import numpy_avail
  3. class TestMatrixBuilder:
  def test_matrix_builder(self):
      """Check the double-valued builder on an empty matrix and on a 5x5 matrix with 7 entries."""
      # A builder with forced dimensions and no entries produces a 0x0 matrix.
      empty = stormpy.SparseMatrixBuilder(force_dimensions=True).build()
      assert empty.nr_columns == 0
      assert empty.nr_rows == 0
      assert empty.nr_entries == 0

      builder = stormpy.SparseMatrixBuilder(5, 5, force_dimensions=False)
      # Entries are added row by row, in increasing column order.
      for row, column, value in ((0, 0, 0), (0, 1, 0.1), (2, 2, 22), (2, 3, 23)):
          builder.add_next_value(row, column, value)
      # The builder reports the position of the most recently added entry.
      assert builder.get_last_column() == 3
      assert builder.get_last_row() == 2
      for row, column, value in ((3, 2, 32), (3, 4, 34), (4, 3, 43)):
          builder.add_next_value(row, column, value)

      matrix = builder.build()
      assert matrix.nr_columns == 5
      assert matrix.nr_rows == 5
      # The explicitly added 0 at (0, 0) is counted as a stored entry.
      assert matrix.nr_entries == 7
      for entry in matrix:
          value = entry.value()
          assert value == 0 or (value == 0.1 and entry.column == 1) or (value > 20 and entry.column > 1)
  def test_parametric_matrix_builder(self):
      """Check the parametric builder on an empty matrix and on a 5x5 matrix of constant functions."""
      # A builder with forced dimensions and no entries produces a 0x0 matrix.
      empty = stormpy.ParametricSparseMatrixBuilder(force_dimensions=True).build()
      assert empty.nr_columns == 0
      assert empty.nr_rows == 0
      assert empty.nr_entries == 0

      builder = stormpy.ParametricSparseMatrixBuilder(5, 5, force_dimensions=False)

      def constant_function(number):
          # Wrap an integer as RationalRF -> FactorizedPolynomial -> FactorizedRationalFunction.
          polynomial = stormpy.FactorizedPolynomial(stormpy.RationalRF(number))
          return stormpy.FactorizedRationalFunction(polynomial)

      first_val = constant_function(1)
      sec_val = constant_function(2)

      for row, column in ((0, 0), (0, 1)):
          builder.add_next_value(row, column, first_val)
      for row, column in ((2, 2), (2, 3)):
          builder.add_next_value(row, column, sec_val)
      # The builder reports the position of the most recently added entry.
      assert builder.get_last_column() == 3
      assert builder.get_last_row() == 2
      for row, column in ((3, 2), (3, 4), (4, 3)):
          builder.add_next_value(row, column, sec_val)

      matrix = builder.build()
      assert matrix.nr_columns == 5
      assert matrix.nr_rows == 5
      assert matrix.nr_entries == 7
      # Columns 0-1 hold the constant 1; columns 2-4 hold the constant 2.
      for entry in matrix:
          expected = first_val if entry.column < 2 else sec_val
          assert entry.value() == expected
  def test_matrix_replace_columns(self):
      """Reverse columns 1..3 via replace_columns and check that every value moved accordingly."""
      builder = stormpy.SparseMatrixBuilder(3, 4, force_dimensions=False)
      # Every entry's value equals its original column index.
      # Row 0 spans columns 0..3; rows 1 and 2 span columns 1..3.
      for row in range(3):
          first_column = 0 if row == 0 else 1
          for column in range(first_column, 4):
              builder.add_next_value(row, column, column)

      # Starting at offset 1, remap columns 1, 2, 3 to 3, 2, 1.
      builder.replace_columns([3, 2, 1], 1)
      matrix = builder.build()
      assert matrix.nr_entries == 10

      # Check that the columns were replaced: value -> column it must now sit in.
      new_column_of_value = {0: 0, 3: 1, 2: 2, 1: 3}
      for entry in matrix:
          assert new_column_of_value.get(entry.value()) == entry.column
  def test_parametric_matrix_replace_columns(self):
      """Swap columns 2 and 3 of a parametric matrix via replace_columns and check the result."""
      builder = stormpy.ParametricSparseMatrixBuilder(3, 4, force_dimensions=False)

      one = stormpy.FactorizedPolynomial(stormpy.RationalRF(1))
      two = stormpy.FactorizedPolynomial(stormpy.RationalRF(2))
      first_val = stormpy.FactorizedRationalFunction(one, one)    # 1/1
      # NOTE(review): 2/2 presumably compares equal to first_val (1/1). If so,
      # the final check cannot tell columns 1 and 3 apart. TODO confirm and
      # consider a distinct value.
      sec_val = stormpy.FactorizedRationalFunction(two, two)      # 2/2
      third_val = stormpy.FactorizedRationalFunction(one, two)    # 1/2

      # Every row holds the same three entries in columns 1, 2, 3.
      row_pattern = ((1, first_val), (2, sec_val), (3, third_val))
      for row in range(3):
          for column, value in row_pattern:
              builder.add_next_value(row, column, value)

      # Starting at offset 2, remap columns 2, 3 to 3, 2 (i.e. swap them).
      builder.replace_columns([3, 2], 2)
      matrix = builder.build()
      assert matrix.nr_entries == 9

      # Check that the columns were replaced.
      for entry in matrix:
          in_column_1 = entry.value() == first_val and entry.column == 1
          in_column_2 = entry.value() == third_val and entry.column == 2
          in_column_3 = entry.value() == sec_val and entry.column == 3
          assert in_column_1 or in_column_2 or in_column_3
  def test_matrix_builder_row_grouping(self):
      """Open two custom row groups (starting at rows 1 and 4) and check their bounds."""
      num_rows = 5
      builder = stormpy.SparseMatrixBuilder(num_rows, 6, has_custom_row_grouping=True, row_groups=2)
      # Each new_row_group call should bump the builder's group count by one.
      for expected_count, start_row in enumerate((1, 4), start=1):
          builder.new_row_group(start_row)
          assert builder.get_current_row_group_count() == expected_count

      matrix = builder.build()
      # Group 0 covers rows [1, 4); the last group runs to the end of the matrix.
      for group, (start, end) in enumerate(((1, 4), (4, num_rows))):
          assert matrix.get_row_group_start(group) == start
          assert matrix.get_row_group_end(group) == end
  def test_parametric_matrix_builder_row_grouping(self):
      """Open two custom row groups on the parametric builder and check their bounds."""
      num_rows = 5
      builder = stormpy.ParametricSparseMatrixBuilder(num_rows, 6, has_custom_row_grouping=True, row_groups=2)
      # Each new_row_group call should bump the builder's group count by one.
      for expected_count, start_row in enumerate((1, 4), start=1):
          builder.new_row_group(start_row)
          assert builder.get_current_row_group_count() == expected_count

      matrix = builder.build()
      # Group 0 covers rows [1, 4); the last group runs to the end of the matrix.
      for group, (start, end) in enumerate(((1, 4), (4, num_rows))):
          assert matrix.get_row_group_start(group) == start
          assert matrix.get_row_group_end(group) == end
  @numpy_avail
  def test_matrix_from_numpy(self):
      """Build a double sparse matrix from a dense 4x2 NumPy array and compare every entry."""
      import numpy as np
      array = np.array([[0, 2],
                        [3, 4],
                        [0.1, 24],
                        [-0.3, -4]], dtype='float64')
      matrix = stormpy.build_sparse_matrix(array)
      # Check matrix dimension
      assert matrix.nr_rows == array.shape[0]
      assert matrix.nr_columns == array.shape[1]
      # The test expects all 8 cells, including the 0 at (0, 0), to be stored.
      assert matrix.nr_entries == 8
      # Check matrix values. Iterate over the rows (shape[0]); the previous
      # range(shape[1]) only visited the first 2 of the 4 rows.
      for r in range(array.shape[0]):
          for e in matrix.get_row(r):
              assert e.value() == array[r, e.column]
  @numpy_avail
  def test_parametric_matrix_from_numpy(self):
      """Build a parametric sparse matrix from a 4x2 NumPy array of rational functions and compare every entry."""
      import numpy as np
      one_pol = stormpy.RationalRF(1)
      one_pol = stormpy.FactorizedPolynomial(one_pol)
      first_val = stormpy.FactorizedRationalFunction(one_pol, one_pol)
      two_pol = stormpy.RationalRF(2)
      two_pol = stormpy.FactorizedPolynomial(two_pol)
      sec_val = stormpy.FactorizedRationalFunction(two_pol, two_pol)
      third_val = stormpy.FactorizedRationalFunction(one_pol, two_pol)
      array = np.array([[sec_val, first_val],
                        [first_val, sec_val],
                        [sec_val, sec_val],
                        [third_val, third_val]])
      matrix = stormpy.build_parametric_sparse_matrix(array)
      # Check matrix dimension
      assert matrix.nr_rows == array.shape[0]
      assert matrix.nr_columns == array.shape[1]
      assert matrix.nr_entries == 8
      # Check matrix values. Iterate over the rows (shape[0]); the previous
      # range(shape[1]) only visited the first 2 of the 4 rows.
      for r in range(array.shape[0]):
          for e in matrix.get_row(r):
              assert e.value() == array[r, e.column]
  @numpy_avail
  def test_matrix_from_numpy_row_grouping(self):
      """Build a double sparse matrix with row groups from a 4x2 NumPy array; check entries and group bounds."""
      import numpy as np
      array = np.array([[0, 2],
                        [3, 4],
                        [0.1, 24],
                        [-0.3, -4]], dtype='float64')
      matrix = stormpy.build_sparse_matrix(array, row_group_indices=[1, 3])
      # Check matrix dimension
      assert matrix.nr_rows == array.shape[0]
      assert matrix.nr_columns == array.shape[1]
      assert matrix.nr_entries == 8
      # Check matrix values. Iterate over the rows (shape[0]); the previous
      # range(shape[1]) only visited the first 2 of the 4 rows.
      for r in range(array.shape[0]):
          for e in matrix.get_row(r):
              assert e.value() == array[r, e.column]
      # Check row groups: starts at rows 1 and 3; the last group ends at nr_rows.
      assert matrix.get_row_group_start(0) == 1
      assert matrix.get_row_group_end(0) == 3
      assert matrix.get_row_group_start(1) == 3
      assert matrix.get_row_group_end(1) == 4
  @numpy_avail
  def test_parametric_matrix_from_numpy_row_grouping(self):
      """Build a parametric sparse matrix with row groups from a 4x2 NumPy array; check entries and group bounds."""
      import numpy as np
      one_pol = stormpy.RationalRF(1)
      one_pol = stormpy.FactorizedPolynomial(one_pol)
      first_val = stormpy.FactorizedRationalFunction(one_pol, one_pol)
      two_pol = stormpy.RationalRF(2)
      two_pol = stormpy.FactorizedPolynomial(two_pol)
      sec_val = stormpy.FactorizedRationalFunction(two_pol, two_pol)
      third_val = stormpy.FactorizedRationalFunction(one_pol, two_pol)
      array = np.array([[sec_val, first_val],
                        [first_val, sec_val],
                        [sec_val, sec_val],
                        [third_val, third_val]])
      matrix = stormpy.build_parametric_sparse_matrix(array, row_group_indices=[1, 3])
      # Check matrix dimension
      assert matrix.nr_rows == array.shape[0]
      assert matrix.nr_columns == array.shape[1]
      assert matrix.nr_entries == 8
      # Check matrix values. Iterate over the rows (shape[0]); the previous
      # range(shape[1]) only visited the first 2 of the 4 rows.
      for r in range(array.shape[0]):
          for e in matrix.get_row(r):
              assert e.value() == array[r, e.column]
      # Check row groups: starts at rows 1 and 3; the last group ends at nr_rows.
      assert matrix.get_row_group_start(0) == 1
      assert matrix.get_row_group_end(0) == 3
      assert matrix.get_row_group_start(1) == 3
      assert matrix.get_row_group_end(1) == 4