You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

273 lines
11 KiB

2 months ago
  1. import stormpy
  2. import stormpy.core
  3. import stormpy.simulator
  4. import stormpy.shields
  5. import stormpy.logic
  6. import stormpy.examples
  7. import stormpy.examples.files
  8. from enum import Enum
  9. from abc import ABC
  10. from PIL import Image, ImageDraw
  11. import re
  12. import sys
  13. import tempfile, datetime, shutil
  14. import numpy as np
  15. import gymnasium as gym
  16. from minigrid.core.actions import Actions
  17. from minigrid.core.state import to_state, State
  18. import os
  19. import time
  20. import argparse
  21. def tic():
  22. #Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
  23. global startTime_for_tictoc
  24. startTime_for_tictoc = time.time()
  25. def toc():
  26. if 'startTime_for_tictoc' in globals():
  27. print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
  28. else:
  29. print("Toc: start time not set")
  30. class ShieldingConfig(Enum):
  31. Training = 'training'
  32. Evaluation = 'evaluation'
  33. Disabled = 'none'
  34. Full = 'full'
  35. def __str__(self) -> str:
  36. return self.value
  37. def shield_needed(shielding):
  38. return shielding in [ShieldingConfig.Training, ShieldingConfig.Evaluation, ShieldingConfig.Full]
  39. def shielded_evaluation(shielding):
  40. return shielding in [ShieldingConfig.Evaluation, ShieldingConfig.Full]
  41. def shielded_training(shielding):
  42. return shielding in [ShieldingConfig.Training, ShieldingConfig.Full]
  43. class ShieldHandler(ABC):
  44. def __init__(self) -> None:
  45. pass
  46. def create_shield(self, **kwargs) -> dict:
  47. pass
  48. class MiniGridShieldHandler(ShieldHandler):
  49. def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False, prism_file=None) -> None:
  50. self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
  51. os.mkdir(self.tmp_dir_name)
  52. self.grid_file = self.tmp_dir_name + "/" + grid_file
  53. self.grid_to_prism_binary = grid_to_prism_binary
  54. self.prism_path = self.tmp_dir_name + "/" + prism_path
  55. self.prism_config = prism_config
  56. self.prism_file = prism_file
  57. self.action_dictionary = None
  58. self.formula = formula
  59. shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
  60. self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
  61. self.nocleanup = nocleanup
  62. def __del__(self):
  63. if not self.nocleanup:
  64. shutil.rmtree(self.tmp_dir_name)
  65. def __export_grid_to_text(self, env):
  66. with open(self.grid_file, "w") as f:
  67. f.write(env.printGrid(init=True))
  68. def __create_prism(self):
  69. if self.prism_file is not None:
  70. print(self.prism_file)
  71. print(self.prism_path)
  72. shutil.copyfile(self.prism_file, self.prism_path)
  73. return
  74. if self.prism_config is None:
  75. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
  76. else:
  77. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  78. assert result == 0, "Prism file could not be generated"
  79. def __create_shield_dict(self):
  80. program = stormpy.parse_prism_program(self.prism_path)
  81. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  82. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  83. options.set_build_state_valuations(True)
  84. options.set_build_choice_labels(True)
  85. options.set_build_all_labels()
  86. print(f"LOG: Starting with explicit model creation...")
  87. tic()
  88. model = stormpy.build_sparse_model_with_options(program, options)
  89. toc()
  90. print(f"LOG: Starting with model checking...")
  91. tic()
  92. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
  93. toc()
  94. assert result.has_shield
  95. shield = result.shield
  96. action_dictionary = dict()
  97. shield_scheduler = shield.construct()
  98. state_valuations = model.state_valuations
  99. choice_labeling = model.choice_labeling
  100. if self.nocleanup:
  101. stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/shield")
  102. print(f"LOG: Starting to translate shield...")
  103. tic()
  104. for stateID in model.states:
  105. choice = shield_scheduler.get_choice(stateID)
  106. choices = choice.choice_map
  107. state_valuation = state_valuations.get_string(stateID)
  108. ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation))
  109. booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation)
  110. booleans = {b[1]: False if b[0] == "!" else True for b in booleans}
  111. if int(ints.get("previousActionAgent", 7)) != 7:
  112. continue
  113. if int(ints.get("clock", 0)) != 0:
  114. continue
  115. state = to_state(ints, booleans)
  116. #print(f"{state} got added with actions:")
  117. #print(get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]))
  118. action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
  119. toc()
  120. #print(f"{len(action_dictionary)} states in the shield")
  121. self.action_dictionary = action_dictionary
  122. # Remove shielding_files_* immediatelly, only to remove clutter for the demo
  123. if not self.nocleanup:
  124. shutil.rmtree(self.tmp_dir_name)
  125. return action_dictionary
  126. def create_shield(self, **kwargs):
  127. if self.action_dictionary is not None:
  128. #print("Returning already calculated shield")
  129. return self.action_dictionary
  130. env = kwargs["env"]
  131. self.__export_grid_to_text(env)
  132. self.__create_prism()
  133. print("Computing new shield")
  134. return self.__create_shield_dict()
  135. def rectangle_for_overlay(x, y, dir, tile_size, width=2, offset=0, thickness=0):
  136. if dir == 0: return (((x+1)*tile_size-width-thickness,y*tile_size+offset), ((x+1)*tile_size,(y+1)*tile_size-offset))
  137. if dir == 1: return ((x*tile_size+offset,(y+1)*tile_size-width-thickness), ((x+1)*tile_size-offset,(y+1)*tile_size))
  138. if dir == 2: return ((x*tile_size,y*tile_size+offset), (x*tile_size+width+thickness,(y+1)*tile_size-offset))
  139. if dir == 3: return ((x*tile_size+offset,y*tile_size), ((x+1)*tile_size-offset,y*tile_size+width+thickness))
  140. def triangle_for_overlay(x,y, dir, tile_size):
  141. offset = tile_size/2
  142. if dir == 0: return [((x+1)*tile_size,y*tile_size), ((x+1)*tile_size,(y+1)*tile_size), ((x+1)*tile_size-offset, y*tile_size+tile_size/2)]
  143. if dir == 1: return [(x*tile_size,(y+1)*tile_size), ((x+1)*tile_size,(y+1)*tile_size), (x*tile_size+tile_size/2, (y+1)*tile_size-offset)]
  144. if dir == 2: return [(x*tile_size,y*tile_size), (x*tile_size,(y+1)*tile_size), (x*tile_size+offset, y*tile_size+tile_size/2)]
  145. if dir == 3: return [(x*tile_size,y*tile_size), ((x+1)*tile_size,y*tile_size), (x*tile_size+tile_size/2, y*tile_size+offset)]
  146. def create_shield_overlay_image(env, shield):
  147. env.reset()
  148. img = Image.fromarray(env.render()).convert("RGBA")
  149. ts = env.tile_size
  150. overlay = Image.new("RGBA", img.size, (255, 255, 255, 0))
  151. draw = ImageDraw.Draw(overlay)
  152. red = (255,0,0,200)
  153. for x in range(0, env.width):
  154. for y in range(0, env.height):
  155. for dir in range(0,4):
  156. try:
  157. if shield[State(x, y, dir, "")][2] <= 0.0:
  158. draw.polygon(triangle_for_overlay(x,y,dir,ts), fill=red)
  159. #else:
  160. # draw.polygon(triangle_for_overlay(x,y,dir,ts), fill=(0, 200, 0, 96))
  161. except KeyError: pass
  162. img = Image.alpha_composite(img, overlay)
  163. img.show()
  164. def expname(args):
  165. return f"{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"
  166. def create_log_dir(args):
  167. log_dir = f"{args.log_dir}/{expname(args)}"
  168. os.makedirs(log_dir, exist_ok=True)
  169. return log_dir
  170. def get_allowed_actions_mask(actions):
  171. action_mask = [0.0] * 7
  172. actions_labels = [label for labels in actions for label in list(labels)]
  173. for action_label in actions_labels:
  174. if "move" in action_label:
  175. action_mask[2] = 1.0
  176. elif "left" in action_label:
  177. action_mask[0] = 1.0
  178. elif "right" in action_label:
  179. action_mask[1] = 1.0
  180. elif "pickup" in action_label:
  181. action_mask[3] = 1.0
  182. elif "drop" in action_label:
  183. action_mask[4] = 1.0
  184. elif "toggle" in action_label:
  185. action_mask[5] = 1.0
  186. elif "done" in action_label:
  187. action_mask[6] = 1.0
  188. return action_mask
  189. def common_parser():
  190. parser = argparse.ArgumentParser()
  191. parser.add_argument("--env",
  192. help="gym environment to load",
  193. choices=gym.envs.registry.keys(),
  194. default="MiniGrid-LavaSlipperyCliff-16x13-v0")
  195. parser.add_argument("--grid_file", default="grid.txt")
  196. parser.add_argument("--prism_file", default=None)
  197. parser.add_argument("--prism_output_file", default="grid.prism")
  198. parser.add_argument("--log_dir", default="../log_results/")
  199. parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
  200. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  201. parser.add_argument("--steps", default=20_000, type=int)
  202. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  203. parser.add_argument("--prism_config", default=None)
  204. parser.add_argument("--shield_value", default=0.9, type=float)
  205. parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
  206. parser.add_argument("--nocleanup", action=argparse.BooleanOptionalAction)
  207. parser.add_argument("--expname_suffix", default="")
  208. return parser
  209. class MiniWrapper(gym.Wrapper):
  210. def __init__(self, env):
  211. super().__init__(env)
  212. self.env = env
  213. def reset(self, *, seed=None, options=None):
  214. obs, info = self.env.reset(seed=seed, options=options)
  215. return obs.transpose(1,0,2), info
  216. def observations(self, obs):
  217. return obs
  218. def step(self, action):
  219. obs, reward, terminated, truncated, info = self.env.step(action)
  220. return obs.transpose(1,0,2), reward, terminated, truncated, info