You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

67 lines
2.3 KiB

  1. import stormpy.utility
  2. from . import storage
  3. from .storage import *
  4. def build_sparse_matrix(array, row_group_indices=[]):
  5. """
  6. Build a sparse matrix from numpy array.
  7. :param numpy array: The array.
  8. :param List[double] row_group_indices: List containing the starting row of each row group in ascending order.
  9. :return: Sparse matrix.
  10. """
  11. num_row = array.shape[0]
  12. num_col = array.shape[1]
  13. len_group_indices = len(row_group_indices)
  14. if len_group_indices > 0:
  15. builder = storage.SparseMatrixBuilder(rows=num_row, columns=num_col, has_custom_row_grouping=True,
  16. row_groups=len_group_indices)
  17. else:
  18. builder = storage.SparseMatrixBuilder(rows=num_row, columns=num_col)
  19. row_group_index = 0
  20. for r in range(num_row):
  21. # check whether to start a custom row group
  22. if row_group_index < len_group_indices and r == row_group_indices[row_group_index]:
  23. builder.new_row_group(r)
  24. row_group_index += 1
  25. # insert values of the current row
  26. for c in range(num_col):
  27. builder.add_next_value(r, c, array[r, c])
  28. return builder.build()
  29. def build_parametric_sparse_matrix(array, row_group_indices=[]):
  30. """
  31. Build a sparse matrix from numpy array.
  32. :param numpy array: The array.
  33. :param List[double] row_group_indices: List containing the starting row of each row group in ascending order.
  34. :return: Parametric sparse matrix.
  35. """
  36. num_row = array.shape[0]
  37. num_col = array.shape[1]
  38. len_group_indices = len(row_group_indices)
  39. if len_group_indices > 0:
  40. builder = storage.ParametricSparseMatrixBuilder(rows=num_row, columns=num_col, has_custom_row_grouping=True,
  41. row_groups=len_group_indices)
  42. else:
  43. builder = storage.ParametricSparseMatrixBuilder(rows=num_row, columns=num_col)
  44. row_group_index = 0
  45. for r in range(num_row):
  46. # check whether to start a custom row group
  47. if row_group_index < len_group_indices and r == row_group_indices[row_group_index]:
  48. builder.new_row_group(r)
  49. row_group_index += 1
  50. # insert values of the current row
  51. for c in range(num_col):
  52. builder.add_next_value(r, c, array[r, c])
  53. return builder.build()