You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

225 lines
8.5 KiB

1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
10 months ago
import argparse
import datetime
import os
import re
import shlex
import shutil
import subprocess
import sys
import tempfile
import time
from abc import ABC, abstractmethod
from enum import Enum

import gymnasium as gym
import stormpy
import stormpy.core
import stormpy.examples
import stormpy.examples.files
import stormpy.logic
import stormpy.shields
import stormpy.simulator
from minigrid.core.actions import Actions
from minigrid.core.state import to_state
  19. def tic():
  20. #Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
  21. global startTime_for_tictoc
  22. startTime_for_tictoc = time.time()
  23. def toc():
  24. if 'startTime_for_tictoc' in globals():
  25. print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
  26. else:
  27. print("Toc: start time not set")
  28. class ShieldingConfig(Enum):
  29. Training = 'training'
  30. Evaluation = 'evaluation'
  31. Disabled = 'none'
  32. Full = 'full'
  33. def __str__(self) -> str:
  34. return self.value
  35. def shield_needed(shielding):
  36. return shielding in [ShieldingConfig.Training, ShieldingConfig.Evaluation, ShieldingConfig.Full]
  37. def shielded_evaluation(shielding):
  38. return shielding in [ShieldingConfig.Evaluation, ShieldingConfig.Full]
  39. def shielded_training(shielding):
  40. return shielding in [ShieldingConfig.Training, ShieldingConfig.Full]
  41. class ShieldHandler(ABC):
  42. def __init__(self) -> None:
  43. pass
  44. def create_shield(self, **kwargs) -> dict:
  45. pass
  46. class MiniGridShieldHandler(ShieldHandler):
  47. def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False, prism_file=None) -> None:
  48. self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
  49. os.mkdir(self.tmp_dir_name)
  50. self.grid_file = self.tmp_dir_name + "/" + grid_file
  51. self.grid_to_prism_binary = grid_to_prism_binary
  52. self.prism_path = self.tmp_dir_name + "/" + prism_path
  53. self.prism_config = prism_config
  54. self.prism_file = prism_file
  55. self.formula = formula
  56. shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
  57. self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
  58. self.nocleanup = nocleanup
  59. def __del__(self):
  60. if not self.nocleanup:
  61. shutil.rmtree(self.tmp_dir_name)
  62. def __export_grid_to_text(self, env):
  63. with open(self.grid_file, "w") as f:
  64. f.write(env.printGrid(init=True))
  65. def __create_prism(self):
  66. if self.prism_file is not None:
  67. print(self.prism_file)
  68. print(self.prism_path)
  69. shutil.copyfile(self.prism_file, self.prism_path)
  70. return
  71. if self.prism_config is None:
  72. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
  73. else:
  74. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  75. assert result == 0, "Prism file could not be generated"
  76. def __create_shield_dict(self):
  77. program = stormpy.parse_prism_program(self.prism_path)
  78. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  79. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  80. options.set_build_state_valuations(True)
  81. options.set_build_choice_labels(True)
  82. options.set_build_all_labels()
  83. print(f"LOG: Starting with explicit model creation...")
  84. tic()
  85. model = stormpy.build_sparse_model_with_options(program, options)
  86. toc()
  87. print(f"LOG: Starting with model checking...")
  88. tic()
  89. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
  90. toc()
  91. assert result.has_shield
  92. shield = result.shield
  93. action_dictionary = dict()
  94. shield_scheduler = shield.construct()
  95. state_valuations = model.state_valuations
  96. choice_labeling = model.choice_labeling
  97. stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/shield")
  98. print(f"LOG: Starting to translate shield...")
  99. tic()
  100. for stateID in model.states:
  101. choice = shield_scheduler.get_choice(stateID)
  102. choices = choice.choice_map
  103. state_valuation = state_valuations.get_string(stateID)
  104. #print(state_valuation)
  105. #print(choices)
  106. ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation))
  107. booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation)
  108. booleans = {b[1]: False if b[0] == "!" else True for b in booleans}
  109. #print(ints, booleans)
  110. if int(ints.get("previousActionAgent", 3)) != 3:
  111. continue
  112. if int(ints.get("clock", 0)) != 0:
  113. continue
  114. state = to_state(ints, booleans)
  115. #print(f"{state} got added with actions:")
  116. #print(get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]))
  117. action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
  118. toc()
  119. #print(f"{len(action_dictionary)} states in the shield")
  120. return action_dictionary
  121. def create_shield(self, **kwargs):
  122. env = kwargs["env"]
  123. self.__export_grid_to_text(env)
  124. self.__create_prism()
  125. return self.__create_shield_dict()
  126. def expname(args):
  127. return f"{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"
  128. def create_log_dir(args):
  129. log_dir = f"{args.log_dir}/{expname(args)}"
  130. os.makedirs(log_dir, exist_ok=True)
  131. return log_dir
  132. def get_allowed_actions_mask(actions):
  133. action_mask = [0.0] * 7
  134. actions_labels = [label for labels in actions for label in list(labels)]
  135. for action_label in actions_labels:
  136. if "move" in action_label:
  137. action_mask[2] = 1.0
  138. elif "left" in action_label:
  139. action_mask[0] = 1.0
  140. elif "right" in action_label:
  141. action_mask[1] = 1.0
  142. elif "pickup" in action_label:
  143. action_mask[3] = 1.0
  144. elif "drop" in action_label:
  145. action_mask[4] = 1.0
  146. elif "toggle" in action_label:
  147. action_mask[5] = 1.0
  148. elif "done" in action_label:
  149. action_mask[6] = 1.0
  150. return action_mask
  151. def common_parser():
  152. parser = argparse.ArgumentParser()
  153. parser.add_argument("--env",
  154. help="gym environment to load",
  155. choices=gym.envs.registry.keys(),
  156. default="MiniGrid-LavaSlipperyCliff-16x13-v0")
  157. parser.add_argument("--grid_file", default="grid.txt")
  158. parser.add_argument("--prism_file", default=None)
  159. parser.add_argument("--prism_output_file", default="grid.prism")
  160. parser.add_argument("--log_dir", default="../log_results/")
  161. parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
  162. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  163. parser.add_argument("--steps", default=20_000, type=int)
  164. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  165. parser.add_argument("--prism_config", default=None)
  166. parser.add_argument("--shield_value", default=0.9, type=float)
  167. parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
  168. parser.add_argument("--nocleanup", action=argparse.BooleanOptionalAction)
  169. parser.add_argument("--expname_suffix", default="")
  170. return parser
  171. class MiniWrapper(gym.Wrapper):
  172. def __init__(self, env):
  173. super().__init__(env)
  174. self.env = env
  175. def reset(self, *, seed=None, options=None):
  176. obs, info = self.env.reset(seed=seed, options=options)
  177. return obs.transpose(1,0,2), info
  178. def observations(self, obs):
  179. return obs
  180. def step(self, action):
  181. obs, reward, terminated, truncated, info = self.env.step(action)
  182. return obs.transpose(1,0,2), reward, terminated, truncated, info