The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

146 lines
4.4 KiB

4 months ago
  1. import stormpy
  2. import stormpy.logic
  3. from stormpy.storage import BitVector
  4. from stormpy.utility import ShortestPathsGenerator
  5. from stormpy.utility import MatrixFormat
  6. from helpers.helper import get_example_path
  7. import pytest
  8. import math
  9. # this is admittedly slightly overengineered
  10. class ModelWithKnownShortestPaths:
  11. """Knuth's die model with reference kSP methods"""
  12. def __init__(self):
  13. self.target_label = "one"
  14. program_path = get_example_path("dtmc", "die.pm")
  15. raw_formula = "P=? [ F \"" + self.target_label + "\" ]"
  16. program = stormpy.parse_prism_program(program_path)
  17. formulas = stormpy.parse_properties_for_prism_program(raw_formula, program)
  18. self.model = stormpy.build_model(program, formulas)
  19. def probability(self, k):
  20. return (1 / 2) ** ((2 * k) + 1)
  21. def state_set(self, k):
  22. return BitVector(self.model.nr_states, [0, 1, 3, 7])
  23. def path(self, k):
  24. path = [0] + k * [1, 3] + [7]
  25. return list(reversed(path)) # SPG returns traversal from back
  26. @pytest.fixture(scope="module", params=[1, 2, 3, 3000, 42])
  27. def index(request):
  28. return request.param
  29. @pytest.fixture(scope="module")
  30. def model_with_known_shortest_paths():
  31. return ModelWithKnownShortestPaths()
  32. @pytest.fixture(scope="module")
  33. def model(model_with_known_shortest_paths):
  34. return model_with_known_shortest_paths.model
  35. @pytest.fixture(scope="module")
  36. def expected_distance(model_with_known_shortest_paths):
  37. return model_with_known_shortest_paths.probability
  38. @pytest.fixture(scope="module")
  39. def expected_state_set(model_with_known_shortest_paths):
  40. return model_with_known_shortest_paths.state_set
  41. @pytest.fixture(scope="module")
  42. def expected_path(model_with_known_shortest_paths):
  43. return model_with_known_shortest_paths.path
  44. @pytest.fixture(scope="module")
  45. def target_label(model_with_known_shortest_paths):
  46. return model_with_known_shortest_paths.target_label
  47. @pytest.fixture
  48. def state(model):
  49. some_state = 7
  50. assert model.nr_states > some_state, "test model too small"
  51. return some_state
  52. @pytest.fixture
  53. def state_list(model):
  54. some_state_list = [4, 5, 7]
  55. assert model.nr_states > max(some_state_list), "test model too small"
  56. return some_state_list
  57. @pytest.fixture
  58. def state_bitvector(model, state_list):
  59. return BitVector(length=model.nr_states, set_entries=state_list)
  60. @pytest.fixture
  61. def transition_matrix(model):
  62. return model.transition_matrix
  63. @pytest.fixture
  64. def target_prob_map(model, state_list):
  65. return {i: (1.0 if i in state_list else 0.0) for i in range(model.nr_states)}
  66. @pytest.fixture
  67. def target_prob_list(target_prob_map):
  68. return [target_prob_map[i] for i in range(max(target_prob_map.keys()))]
  69. @pytest.fixture
  70. def initial_states(model):
  71. return BitVector(model.nr_states, model.initial_states)
  72. @pytest.fixture
  73. def matrix_format():
  74. return MatrixFormat.Straight
  75. class TestShortestPaths:
  76. def test_spg_ctor_bitvector_target(self, model, state_bitvector):
  77. _ = ShortestPathsGenerator(model, state_bitvector)
  78. def test_spg_ctor_single_state_target(self, model, state):
  79. _ = ShortestPathsGenerator(model, state)
  80. def test_spg_ctor_state_list_target(self, model, state_list):
  81. _ = ShortestPathsGenerator(model, state_list)
  82. def test_spg_ctor_label_target(self, model, target_label):
  83. _ = ShortestPathsGenerator(model, target_label)
  84. def test_spg_ctor_matrix_vector(self, transition_matrix, target_prob_list, initial_states, matrix_format):
  85. _ = ShortestPathsGenerator(transition_matrix, target_prob_list, initial_states, matrix_format)
  86. def test_spg_ctor_matrix_map(self, transition_matrix, target_prob_map, initial_states, matrix_format):
  87. _ = ShortestPathsGenerator(transition_matrix, target_prob_map, initial_states, matrix_format)
  88. def test_spg_distance(self, model, target_label, index, expected_distance):
  89. spg = ShortestPathsGenerator(model, target_label)
  90. assert math.isclose(spg.get_distance(index), expected_distance(index))
  91. def test_spg_state_set(self, model, target_label, index, expected_state_set):
  92. spg = ShortestPathsGenerator(model, target_label)
  93. assert spg.get_states(index) == expected_state_set(index)
  94. def test_spg_state_list(self, model, target_label, index, expected_path):
  95. spg = ShortestPathsGenerator(model, target_label)
  96. assert spg.get_path_as_list(index) == expected_path(index)