You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

164 lines
5.9 KiB

11 months ago
11 months ago
11 months ago
12 months ago
12 months ago
12 months ago
  1. import stormpy
  2. import stormpy.core
  3. import stormpy.simulator
  4. import stormpy.shields
  5. import stormpy.logic
  6. import stormpy.examples
  7. import stormpy.examples.files
  8. from enum import Enum
  9. from abc import ABC
  10. import re
  11. import sys
  12. from minigrid.core.actions import Actions
  13. from minigrid.core.state import to_state
  14. import os
  15. import time
  16. import argparse
  17. def tic():
  18. #Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
  19. global startTime_for_tictoc
  20. startTime_for_tictoc = time.time()
  21. def toc():
  22. if 'startTime_for_tictoc' in globals():
  23. print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
  24. else:
  25. print("Toc: start time not set")
  26. class ShieldingConfig(Enum):
  27. Training = 'training'
  28. Evaluation = 'evaluation'
  29. Disabled = 'none'
  30. Full = 'full'
  31. def __str__(self) -> str:
  32. return self.value
  33. class ShieldHandler(ABC):
  34. def __init__(self) -> None:
  35. pass
  36. def create_shield(self, **kwargs) -> dict:
  37. pass
  38. class MiniGridShieldHandler(ShieldHandler):
  39. def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute') -> None:
  40. self.grid_file = grid_file
  41. self.grid_to_prism_binary = grid_to_prism_binary
  42. self.prism_path = prism_path
  43. self.prism_config = prism_config
  44. self.formula = formula
  45. shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
  46. self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
  47. def __export_grid_to_text(self, env):
  48. f = open(self.grid_file, "w")
  49. f.write(env.printGrid(init=True))
  50. f.close()
  51. def __create_prism(self):
  52. if self.prism_config is None:
  53. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
  54. else:
  55. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  56. assert result == 0, "Prism file could not be generated"
  57. def __create_shield_dict(self):
  58. program = stormpy.parse_prism_program(self.prism_path)
  59. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  60. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  61. options.set_build_state_valuations(True)
  62. options.set_build_choice_labels(True)
  63. options.set_build_all_labels()
  64. print(f"LOG: Starting with explicit model creation...")
  65. tic()
  66. model = stormpy.build_sparse_model_with_options(program, options)
  67. toc()
  68. print(f"LOG: Starting with model checking...")
  69. tic()
  70. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
  71. toc()
  72. assert result.has_shield
  73. shield = result.shield
  74. action_dictionary = dict()
  75. shield_scheduler = shield.construct()
  76. state_valuations = model.state_valuations
  77. choice_labeling = model.choice_labeling
  78. #stormpy.shields.export_shield(model, shield, "current.shield")
  79. for stateID in model.states:
  80. choice = shield_scheduler.get_choice(stateID)
  81. choices = choice.choice_map
  82. state_valuation = state_valuations.get_string(stateID)
  83. ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=([a-zA-Z0-9]+)', state_valuation))
  84. booleans = dict(re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]', state_valuation)) #TODO does not parse everything correctly?
  85. if int(ints.get("previousActionAgent", 3)) != 3:
  86. continue
  87. if int(ints.get("clock", 0)) != 0:
  88. continue
  89. state = to_state(ints, booleans)
  90. action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
  91. return action_dictionary
  92. def create_shield(self, **kwargs):
  93. env = kwargs["env"]
  94. self.__export_grid_to_text(env)
  95. self.__create_prism()
  96. return self.__create_shield_dict()
  97. def create_log_dir(args):
  98. return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}"
  99. def test_name(args):
  100. return F"{args.expname}"
  101. def get_allowed_actions_mask(actions):
  102. action_mask = [0.0] * 3 + [1.0] * 4
  103. actions_labels = [label for labels in actions for label in list(labels)]
  104. for action_label in actions_labels:
  105. if "move" in action_label:
  106. action_mask[2] = 1.0
  107. elif "left" in action_label:
  108. action_mask[0] = 1.0
  109. elif "right" in action_label:
  110. action_mask[1] = 1.0
  111. return action_mask
  112. def common_parser():
  113. parser = argparse.ArgumentParser()
  114. parser.add_argument("--env",
  115. help="gym environment to load",
  116. default="MiniGrid-LavaSlipperyCliff-16x12-v0")
  117. parser.add_argument("--grid_file", default="grid.txt")
  118. parser.add_argument("--prism_output_file", default="grid.prism")
  119. parser.add_argument("--log_dir", default="../log_results/")
  120. parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
  121. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  122. parser.add_argument("--steps", default=20_000, type=int)
  123. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  124. parser.add_argument("--prism_config", default=None)
  125. parser.add_argument("--shield_value", default=0.9, type=float)
  126. parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
  127. return parser