You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

150 lines
6.3 KiB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
  1. import stormpy
  2. import stormpy.logic
  3. from helpers.helper import get_example_path
  4. import pytest
  5. class TestSparseModelComponents:
  def test_init_default(self) -> None:
      """Check that a default-constructed SparseModelComponents is empty.

      A freshly constructed instance is expected to carry no state labels,
      no reward models, a 0x0 transition matrix, no Markovian states, no
      player-1 matrix, and a falsy ``rate_transitions`` flag.
      """
      empty = stormpy.SparseModelComponents()

      # Labeling and reward models start out empty.
      assert empty.state_labeling.get_labels() == set()
      assert empty.reward_models == {}

      # The default transition matrix has neither rows nor columns.
      matrix = empty.transition_matrix
      assert matrix.nr_rows == 0
      assert matrix.nr_columns == 0

      # Optional, model-type-specific components are unset by default.
      assert empty.markovian_states is None
      assert empty.player1_matrix is None
      assert not empty.rate_transitions

      # TODO: add default-construction checks for CTMC, MA, MDP and
      # (possibly) POMDP components.
  19. def test_build_dtmc_from_model_components(self):
  20. nr_states = 13
  21. nr_choices = 13
  22. # TransitionMatrix
  23. builder = stormpy.SparseMatrixBuilder(rows=0, columns=0, entries=0, force_dimensions=False,
  24. has_custom_row_grouping=False, row_groups=0)
  25. # Add transitions
  26. builder.add_next_value(0, 1, 0.5)
  27. builder.add_next_value(0, 2, 0.5)
  28. builder.add_next_value(1, 3, 0.5)
  29. builder.add_next_value(1, 4, 0.5)
  30. builder.add_next_value(2, 5, 0.5)
  31. builder.add_next_value(2, 6, 0.5)
  32. builder.add_next_value(3, 7, 0.5)
  33. builder.add_next_value(3, 1, 0.5)
  34. builder.add_next_value(4, 8, 0.5)
  35. builder.add_next_value(4, 9, 0.5)
  36. builder.add_next_value(5, 10, 0.5)
  37. builder.add_next_value(5, 11, 0.5)
  38. builder.add_next_value(6, 2, 0.5)
  39. builder.add_next_value(6, 12, 0.5)
  40. for s in range(7, 13):
  41. builder.add_next_value(s, s, 1)
  42. # Build transition matrix, update number of rows and columns
  43. transition_matrix = builder.build(nr_states, nr_states)
  44. # StateLabeling
  45. state_labeling = stormpy.storage.StateLabeling(nr_states)
  46. state_labels = {'init', 'one', 'two', 'three', 'four', 'five', 'six', 'done', 'deadlock'}
  47. for label in state_labels:
  48. state_labeling.add_label(label)
  49. # Add label to one state
  50. state_labeling.add_label_to_state('init', 0)
  51. state_labeling.add_label_to_state('one', 7)
  52. state_labeling.add_label_to_state('two', 8)
  53. state_labeling.add_label_to_state('three', 9)
  54. state_labeling.add_label_to_state('four', 10)
  55. state_labeling.add_label_to_state('five', 11)
  56. state_labeling.add_label_to_state('six', 12)
  57. # Set the labeling of states given in a bit vector, where length = nr_states
  58. state_labeling.set_states('done', stormpy.BitVector(nr_states, [7, 8, 9, 10, 11, 12]))
  59. # RewardModels
  60. reward_models = {}
  61. # Create a vector representing the state-action rewards
  62. action_reward = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
  63. reward_models['coin_flips'] = stormpy.SparseRewardModel(optional_state_action_reward_vector=action_reward)
  64. # StateValuations
  65. manager = stormpy.ExpressionManager()
  66. var_s = manager.create_integer_variable(name='s')
  67. var_d = manager.create_integer_variable(name='d')
  68. v_builder = stormpy.StateValuationsBuilder()
  69. v_builder.add_variable(var_s)
  70. v_builder.add_variable(var_d)
  71. for s in range(7):
  72. v_builder.add_state(state=s, integer_values=[s, 0])
  73. for s in range(7, 13):
  74. v_builder.add_state(state=s, integer_values=[7, s - 6])
  75. state_valuations = v_builder.build(13)
  76. # todo choice origins:
  77. prism_program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm"))
  78. index_to_identifier_mapping = [1, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 8]
  79. id_to_command_set_mapping = [stormpy.FlatSet() for _ in range(9)]
  80. for i in range(1, 8): # 0: no origin
  81. id_to_command_set_mapping[i].insert(i - 1)
  82. id_to_command_set_mapping[8].insert(7)
  83. #
  84. choice_origins = stormpy.PrismChoiceOrigins(prism_program, index_to_identifier_mapping,
  85. id_to_command_set_mapping)
  86. components = stormpy.SparseModelComponents(transition_matrix=transition_matrix, state_labeling=state_labeling,
  87. reward_models=reward_models)
  88. components.choice_origins = choice_origins
  89. components.state_valuations = state_valuations
  90. dtmc = stormpy.storage.SparseDtmc(components)
  91. assert type(dtmc) is stormpy.SparseDtmc
  92. assert not dtmc.supports_parameters
  93. # test transition matrix
  94. assert dtmc.nr_choices == nr_choices
  95. assert dtmc.nr_states == nr_states
  96. assert dtmc.nr_transitions == 20
  97. assert dtmc.transition_matrix.nr_entries == dtmc.nr_transitions
  98. for e in dtmc.transition_matrix:
  99. assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 6)
  100. for state in dtmc.states:
  101. assert len(state.actions) <= 1
  102. # test state_labeling
  103. assert dtmc.labeling.get_labels() == {'init', 'deadlock', 'done', 'one', 'two', 'three', 'four', 'five', 'six'}
  104. # test reward_models
  105. assert len(dtmc.reward_models) == 1
  106. assert not dtmc.reward_models["coin_flips"].has_state_rewards
  107. assert dtmc.reward_models["coin_flips"].has_state_action_rewards
  108. for reward in dtmc.reward_models["coin_flips"].state_action_rewards:
  109. assert reward == 1.0 or reward == 0.0
  110. assert not dtmc.reward_models["coin_flips"].has_transition_rewards
  111. # test choice_labeling
  112. assert not dtmc.has_choice_labeling()
  113. # test state_valuations
  114. assert dtmc.has_state_valuations()
  115. assert dtmc.state_valuations
  116. value_s = [None] * nr_states
  117. value_d = [None] * nr_states
  118. for s in range(0, dtmc.nr_states):
  119. value_s[s] = dtmc.state_valuations.get_integer_value(s, var_s)
  120. value_d[s] = dtmc.state_valuations.get_integer_value(s, var_d)
  121. assert value_s == [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7]
  122. assert value_d == [0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6]
  123. # todo choice_origins more tests
  124. assert dtmc.has_choice_origins()
  125. assert dtmc.choice_origins is components.choice_origins
  126. assert dtmc.choice_origins.get_number_of_identifiers() == 9