You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
2.1 KiB

5 years ago
5 years ago
  1. import stormpy
  2. from stormpy.examples.files import prism_dtmc_die
  3. class _DfsQueue:
  4. def __init__(self):
  5. self.queue = []
  6. self.visited = set()
  7. def push(self, state_id):
  8. if state_id not in self.visited:
  9. self.queue.append(state_id)
  10. self.visited.add(state_id)
  11. def pop(self):
  12. if len(self.queue) > 0:
  13. return self.queue.pop()
  14. return None
  15. def _dfs_explore(program, callback):
  16. generator = stormpy.StateGenerator(program)
  17. queue = _DfsQueue()
  18. current_state_id = generator.load_initial_state()
  19. queue.visited.add(current_state_id)
  20. while True:
  21. callback(current_state_id, generator)
  22. successors = generator.expand()
  23. assert len(successors) <= 1
  24. for choice in successors:
  25. for state_id, _prob in choice.distribution:
  26. queue.push(state_id)
  27. current_state_id = queue.pop()
  28. if current_state_id is None:
  29. break
  30. generator.load(current_state_id)
  31. def _load_program(filename):
  32. program = stormpy.parse_prism_program(filename) # pylint: disable=no-member
  33. program = program.substitute_constants()
  34. expression_parser = stormpy.ExpressionParser(program.expression_manager)
  35. expression_parser.set_identifier_mapping(
  36. {var.name: var.get_expression()
  37. for var in program.variables})
  38. return program, expression_parser
  39. def _find_variable(program, name):
  40. for var in program.variables:
  41. if var.name is name:
  42. return var
  43. return None
  44. def test_knuth_yao_die():
  45. program, expression_parser = _load_program(prism_dtmc_die)
  46. s_variable = _find_variable(program, "s")
  47. upper_bound_invariant = expression_parser.parse("d <= 6")
  48. number_states = 0
  49. def callback(_state_id, generator):
  50. nonlocal number_states
  51. number_states += 1
  52. valuation = generator.current_state_to_valuation()
  53. assert valuation.integer_values[s_variable] <= 7
  54. assert generator.current_state_satisfies(upper_bound_invariant)
  55. _dfs_explore(program, callback)
  56. assert number_states == 13