You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

190 lines
7.0 KiB

11 months ago
11 months ago
11 months ago
12 months ago
12 months ago
12 months ago
11 months ago
11 months ago
import argparse
import datetime
import os
import re
import shlex
import shutil
import sys
import tempfile
import time
from abc import ABC
from enum import Enum

import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import gymnasium as gym
from minigrid.core.actions import Actions
from minigrid.core.state import to_state
  19. def tic():
  20. #Homemade version of matlab tic and toc functions: https://stackoverflow.com/a/18903019
  21. global startTime_for_tictoc
  22. startTime_for_tictoc = time.time()
  23. def toc():
  24. if 'startTime_for_tictoc' in globals():
  25. print("Elapsed time is " + str(time.time() - startTime_for_tictoc) + " seconds.")
  26. else:
  27. print("Toc: start time not set")
  28. class ShieldingConfig(Enum):
  29. Training = 'training'
  30. Evaluation = 'evaluation'
  31. Disabled = 'none'
  32. Full = 'full'
  33. def __str__(self) -> str:
  34. return self.value
  35. class ShieldHandler(ABC):
  36. def __init__(self) -> None:
  37. pass
  38. def create_shield(self, **kwargs) -> dict:
  39. pass
  40. class MiniGridShieldHandler(ShieldHandler):
  41. def __init__(self, grid_to_prism_binary, grid_file, prism_path, formula, prism_config=None, shield_value=0.9, shield_comparison='absolute', nocleanup=False) -> None:
  42. self.tmp_dir_name = f"shielding_files_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
  43. os.mkdir(self.tmp_dir_name)
  44. self.grid_file = self.tmp_dir_name + "/" + grid_file
  45. self.grid_to_prism_binary = grid_to_prism_binary
  46. self.prism_path = self.tmp_dir_name + "/" + prism_path
  47. self.prism_config = prism_config
  48. self.formula = formula
  49. shield_comparison = stormpy.logic.ShieldComparison.ABSOLUTE if shield_comparison == "absolute" else stormpy.logic.ShieldComparison.RELATIVE
  50. self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comparison, shield_value)
  51. self.nocleanup = nocleanup
  52. def __del__(self):
  53. if not self.nocleanup:
  54. shutil.rmtree(self.tmp_dir_name)
  55. def __export_grid_to_text(self, env):
  56. with open(self.grid_file, "w") as f:
  57. f.write(env.printGrid(init=True))
  58. def __create_prism(self):
  59. if self.prism_config is None:
  60. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
  61. else:
  62. result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  63. assert result == 0, "Prism file could not be generated"
  64. def __create_shield_dict(self):
  65. program = stormpy.parse_prism_program(self.prism_path)
  66. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  67. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  68. options.set_build_state_valuations(True)
  69. options.set_build_choice_labels(True)
  70. options.set_build_all_labels()
  71. print(f"LOG: Starting with explicit model creation...")
  72. tic()
  73. model = stormpy.build_sparse_model_with_options(program, options)
  74. toc()
  75. print(f"LOG: Starting with model checking...")
  76. tic()
  77. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=self.shield_expression)
  78. toc()
  79. assert result.has_shield
  80. shield = result.shield
  81. action_dictionary = dict()
  82. shield_scheduler = shield.construct()
  83. state_valuations = model.state_valuations
  84. choice_labeling = model.choice_labeling
  85. #stormpy.shields.export_shield(model, shield, "current.shield")
  86. for stateID in model.states:
  87. choice = shield_scheduler.get_choice(stateID)
  88. choices = choice.choice_map
  89. state_valuation = state_valuations.get_string(stateID)
  90. ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=([a-zA-Z0-9]+)', state_valuation))
  91. booleans = dict(re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]', state_valuation)) #TODO does not parse everything correctly?
  92. if int(ints.get("previousActionAgent", 3)) != 3:
  93. continue
  94. if int(ints.get("clock", 0)) != 0:
  95. continue
  96. state = to_state(ints, booleans)
  97. action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
  98. return action_dictionary
  99. def create_shield(self, **kwargs):
  100. env = kwargs["env"]
  101. self.__export_grid_to_text(env)
  102. self.__create_prism()
  103. return self.__create_shield_dict()
  104. def expname(args):
  105. return f"{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}_{args.expname_suffix}"
  106. def create_log_dir(args):
  107. log_dir = f"{args.log_dir}/{expname(args)}"
  108. os.makedirs(log_dir, exist_ok=True)
  109. return log_dir
  110. def get_allowed_actions_mask(actions):
  111. action_mask = [0.0] * 3 + [1.0] * 4
  112. actions_labels = [label for labels in actions for label in list(labels)]
  113. for action_label in actions_labels:
  114. if "move" in action_label:
  115. action_mask[2] = 1.0
  116. elif "left" in action_label:
  117. action_mask[0] = 1.0
  118. elif "right" in action_label:
  119. action_mask[1] = 1.0
  120. return action_mask
  121. def common_parser():
  122. parser = argparse.ArgumentParser()
  123. parser.add_argument("--env",
  124. help="gym environment to load",
  125. default="MiniGrid-LavaSlipperyCliff-16x13-v0")
  126. parser.add_argument("--grid_file", default="grid.txt")
  127. parser.add_argument("--prism_output_file", default="grid.prism")
  128. parser.add_argument("--log_dir", default="../log_results/")
  129. parser.add_argument("--formula", default="Pmax=? [G !AgentIsOnLava]")
  130. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  131. parser.add_argument("--steps", default=20_000, type=int)
  132. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  133. parser.add_argument("--prism_config", default=None)
  134. parser.add_argument("--shield_value", default=0.9, type=float)
  135. parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute'])
  136. parser.add_argument("--nocleanup", action=argparse.BooleanOptionalAction)
  137. parser.add_argument("--expname_suffix", default="")
  138. return parser
  139. class MiniWrapper(gym.Wrapper):
  140. def __init__(self, env):
  141. super().__init__(env)
  142. self.env = env
  143. def reset(self, *, seed=None, options=None):
  144. obs, info = self.env.reset(seed=seed, options=options)
  145. return obs.transpose(1,0,2), info
  146. def observations(self, obs):
  147. return obs
  148. def step(self, action):
  149. obs, reward, terminated, truncated, info = self.env.step(action)
  150. return obs.transpose(1,0,2), reward, terminated, truncated, info