The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

76 lines
2.2 KiB

4 months ago
  1. import stormpy
  2. from stormpy.examples.files import prism_dtmc_die
  3. import pytest
  4. class _DfsQueue:
  5. def __init__(self):
  6. self.queue = []
  7. self.visited = set()
  8. def push(self, state_id):
  9. if state_id not in self.visited:
  10. self.queue.append(state_id)
  11. self.visited.add(state_id)
  12. def pop(self):
  13. if len(self.queue) > 0:
  14. return self.queue.pop()
  15. return None
  16. def _dfs_explore(program, callback):
  17. generator = stormpy.StateGenerator(program)
  18. queue = _DfsQueue()
  19. current_state_id = generator.load_initial_state()
  20. queue.visited.add(current_state_id)
  21. while True:
  22. callback(current_state_id, generator)
  23. successors = generator.expand()
  24. assert len(successors) <= 1
  25. for choice in successors:
  26. for state_id, _prob in choice.distribution:
  27. queue.push(state_id)
  28. current_state_id = queue.pop()
  29. if current_state_id is None:
  30. break
  31. generator.load(current_state_id)
  32. def _load_program(filename):
  33. program = stormpy.parse_prism_program(filename) # pylint: disable=no-member
  34. program = program.substitute_constants()
  35. expression_parser = stormpy.ExpressionParser(program.expression_manager)
  36. expression_parser.set_identifier_mapping(
  37. {var.name: var.get_expression()
  38. for var in program.variables})
  39. return program, expression_parser
  40. def _find_variable(program, name):
  41. for var in program.variables:
  42. if var.name is name:
  43. return var
  44. return None
  45. @pytest.mark.skipif(True, reason="State generation is broken")
  46. def test_knuth_yao_die():
  47. program, expression_parser = _load_program(prism_dtmc_die)
  48. s_variable = _find_variable(program, "s")
  49. upper_bound_invariant = expression_parser.parse("d <= 6")
  50. number_states = 0
  51. def callback(_state_id, generator):
  52. nonlocal number_states
  53. number_states += 1
  54. valuation = generator.current_state_to_valuation()
  55. assert valuation.integer_values[s_variable] <= 7
  56. assert generator.current_state_satisfies(upper_bound_invariant)
  57. _dfs_explore(program, callback)
  58. assert number_states == 13