You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

202 lines
7.5 KiB

11 months ago
11 months ago
11 months ago
12 months ago
12 months ago
12 months ago
11 months ago
11 months ago
  1. import stormpy
  2. import stormpy.core
  3. import stormpy.simulator
  4. import stormpy.shields
  5. import stormpy.logic
  6. import stormpy.examples
  7. import stormpy.examples.files
  8. from enum import Enum
  9. from abc import ABC
  10. import re
  11. import sys
  12. import tempfile, datetime, shutil
  13. import gymnasium as gym
  14. from minigrid.core.actions import Actions
  15. from minigrid.core.state import to_state
  16. import os
  17. import time
  18. import argparse
  19. def tic():
  20. #Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
  21. global startTime_for_tictoc
  22. startTime_for_tictoc = time.time()
  23. def toc():
  24. if 'startTime_for_tictoc' in globals():
  25. print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
  26. else:
  27. print("Toc: start time not set")
  28. class ShieldingConfig(Enum):
  29. Training = 'training'
  30. Evaluation = 'evaluation'
  31. Disabled = 'none'
  32. Full = 'full'
  33. def __str__(self) -> str:
  34. return self.value
  35. def shield_needed(shielding):
  36. return shielding in [ShieldingConfig.Training, ShieldingConfig.Evaluation, ShieldingConfig.Full]
  37. def shielded_evaluation(shielding):
  38. return shielding in [ShieldingConfig.Evaluation, ShieldingConfig.Full]
  39. def shielded_training(shielding):
  40. return shielding in [ShieldingConfig.Training, ShieldingConfig.Full]
  41. class ShieldHandler(ABC):
  42. def __init__(self) -> None:
  43. pass
  44. def create_shield(self, **kwargs) -> dict:
  45. pass
  46. class MiniGridShieldHandler(ShieldHandler):
  47. def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False) -> None:
  48. self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
  49. os.mkdir(self.tmp_dir_name)
  50. self.grid_file = self.tmp_dir_name + "/" + grid_file
  51. self.grid_to_prism_binary = grid_to_prism_binary
  52. self.prism_path = self.tmp_dir_name + "/" + prism_path
  53. self.prism_config = prism_config
  54. self.formula = formula
  55. shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
  56. self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
  57. self.nocleanup = nocleanup
  58. def __del__(self):
  59. if not self.nocleanup:
  60. shutil.rmtree(self.tmp_dir_name)
  61. def __export_grid_to_text(self, env):
  62. with open(self.grid_file, "w") as f:
  63. f.write(env.printGrid(init=True))
  64. def __create_prism(self):
  65. if self.prism_config is None:
  66. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
  67. else:
  68. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  69. assert result == 0, "Prism file could not be generated"
  70. def __create_shield_dict(self):
  71. program = stormpy.parse_prism_program(self.prism_path)
  72. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  73. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  74. options.set_build_state_valuations(True)
  75. options.set_build_choice_labels(True)
  76. options.set_build_all_labels()
  77. print(f"LOG: Starting with explicit model creation...")
  78. tic()
  79. model = stormpy.build_sparse_model_with_options(program, options)
  80. toc()
  81. print(f"LOG: Starting with model checking...")
  82. tic()
  83. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
  84. toc()
  85. assert result.has_shield
  86. shield = result.shield
  87. action_dictionary = dict()
  88. shield_scheduler = shield.construct()
  89. state_valuations = model.state_valuations
  90. choice_labeling = model.choice_labeling
  91. #stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/current.shield")
  92. print(f"LOG: Starting to translate shield...")
  93. tic()
  94. for stateID in model.states:
  95. choice = shield_scheduler.get_choice(stateID)
  96. choices = choice.choice_map
  97. state_valuation = state_valuations.get_string(stateID)
  98. ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation))
  99. booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation)
  100. booleans = {b[1]: False if b[0] == "!" else True for b in booleans}
  101. if int(ints.get("previousActionAgent", 3)) != 3:
  102. continue
  103. if int(ints.get("clock", 0)) != 0:
  104. continue
  105. state = to_state(ints, booleans)
  106. action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
  107. toc()
  108. return action_dictionary
  109. def create_shield(self, **kwargs):
  110. env = kwargs["env"]
  111. self.__export_grid_to_text(env)
  112. self.__create_prism()
  113. return self.__create_shield_dict()
  114. def expname(args):
  115. return f"{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"
  116. def create_log_dir(args):
  117. log_dir = f"{args.log_dir}/{expname(args)}"
  118. os.makedirs(log_dir, exist_ok=True)
  119. return log_dir
  120. def get_allowed_actions_mask(actions):
  121. action_mask = [0.0] * 3 + [1.0] * 4
  122. actions_labels = [label for labels in actions for label in list(labels)]
  123. for action_label in actions_labels:
  124. if "move" in action_label:
  125. action_mask[2] = 1.0
  126. elif "left" in action_label:
  127. action_mask[0] = 1.0
  128. elif "right" in action_label:
  129. action_mask[1] = 1.0
  130. return action_mask
  131. def common_parser():
  132. parser = argparse.ArgumentParser()
  133. parser.add_argument("--env",
  134. help="gym environment to load",
  135. default="MiniGrid-LavaSlipperyCliff-16x13-v0")
  136. parser.add_argument("--grid_file", default="grid.txt")
  137. parser.add_argument("--prism_output_file", default="grid.prism")
  138. parser.add_argument("--log_dir", default="../log_results/")
  139. parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
  140. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  141. parser.add_argument("--steps", default=20_000, type=int)
  142. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  143. parser.add_argument("--prism_config", default=None)
  144. parser.add_argument("--shield_value", default=0.9, type=float)
  145. parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
  146. parser.add_argument("--nocleanup", action=argparse.BooleanOptionalAction)
  147. parser.add_argument("--expname_suffix", default="")
  148. return parser
  149. class MiniWrapper(gym.Wrapper):
  150. def __init__(self, env):
  151. super().__init__(env)
  152. self.env = env
  153. def reset(self, *, seed=None, options=None):
  154. obs, info = self.env.reset(seed=seed, options=options)
  155. return obs.transpose(1,0,2), info
  156. def observations(self, obs):
  157. return obs
  158. def step(self, action):
  159. obs, reward, terminated, truncated, info = self.env.step(action)
  160. return obs.transpose(1,0,2), reward, terminated, truncated, info