You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

320 lines
11 KiB

11 months ago
11 months ago
12 months ago
12 months ago
11 months ago
11 months ago
12 months ago
  1. import stormpy
  2. import stormpy.core
  3. import stormpy.simulator
  4. import stormpy.shields
  5. import stormpy.logic
  6. import stormpy.examples
  7. import stormpy.examples.files
  8. from enum import Enum
  9. from abc import ABC
  10. from minigrid.core.actions import Actions
  11. import os
  12. import time
  13. class Action():
  14. def __init__(self, idx, prob=1, labels=[]) -> None:
  15. self.idx = idx
  16. self.prob = prob
  17. self.labels = labels
  18. class ShieldHandler(ABC):
  19. def __init__(self) -> None:
  20. pass
  21. def create_shield(self, **kwargs) -> dict:
  22. pass
  23. class MiniGridShieldHandler(ShieldHandler):
  24. def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None:
  25. self.grid_file = grid_file
  26. self.grid_to_prism_path = grid_to_prism_path
  27. self.prism_path = prism_path
  28. self.formula = formula
  29. self.prism_config = prism_config
  30. self.shield_value = shield_value
  31. self.shield_comparision = shield_comparision
  32. def __export_grid_to_text(self, env):
  33. f = open(self.grid_file, "w")
  34. f.write(env.printGrid(init=True))
  35. f.close()
  36. def __create_prism(self):
  37. # result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
  38. if self.prism_config is None:
  39. result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}")
  40. else:
  41. result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  42. # result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
  43. assert result == 0, "Prism file could not be generated"
  44. f = open(self.prism_path, "a")
  45. f.write("label \"AgentIsInLava\" = AgentIsInLava;")
  46. f.close()
  47. def __create_shield_dict(self):
  48. print(self.prism_path)
  49. program = stormpy.parse_prism_program(self.prism_path)
  50. shield_comp = stormpy.logic.ShieldComparison.RELATIVE
  51. if self.shield_comparision == 'absolute':
  52. shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
  53. shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
  54. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  55. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  56. options.set_build_state_valuations(True)
  57. options.set_build_choice_labels(True)
  58. options.set_build_all_labels()
  59. model = stormpy.build_sparse_model_with_options(program, options)
  60. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
  61. assert result.has_shield
  62. shield = result.shield
  63. action_dictionary = {}
  64. shield_scheduler = shield.construct()
  65. state_valuations = model.state_valuations
  66. choice_labeling = model.choice_labeling
  67. stormpy.shields.export_shield(model, shield, "myshield")
  68. for stateID in model.states:
  69. choice = shield_scheduler.get_choice(stateID)
  70. choices = choice.choice_map
  71. state_valuation = state_valuations.get_string(stateID)
  72. actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
  73. action_dictionary[state_valuation] = actions_to_be_executed
  74. return action_dictionary
  75. def create_shield(self, **kwargs):
  76. env = kwargs["env"]
  77. self.__export_grid_to_text(env)
  78. self.__create_prism()
  79. return self.__create_shield_dict()
  80. def create_shield_query(env):
  81. coordinates = env.env.agent_pos
  82. view_direction = env.env.agent_dir
  83. keys = extract_keys(env)
  84. doors = extract_doors(env)
  85. adversaries = extract_adversaries(env)
  86. if env.carrying:
  87. agent_carrying = F"Agent_is_carrying_object\t"
  88. else:
  89. agent_carrying = "!Agent_is_carrying_object\t"
  90. key_positions = []
  91. agent_key_status = []
  92. for key in keys:
  93. key_color = key[0].color
  94. key_x = key[1]
  95. key_y = key[2]
  96. if env.carrying and env.carrying.type == "key":
  97. agent_key_text = F"Agent_has_{env.carrying.color}_key\t& "
  98. key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
  99. else:
  100. agent_key_text = F"!Agent_has_{key_color}_key\t& "
  101. key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
  102. key_positions.append(key_position)
  103. agent_key_status.append(agent_key_text)
  104. if key_positions:
  105. key_positions[-1] = key_positions[-1].strip()
  106. door_status = []
  107. for door in doors:
  108. status = ""
  109. if door.is_open:
  110. status = F"!Door{door.color}locked\t& Door{door.color}open\t&"
  111. elif door.is_locked:
  112. status = F"Door{door.color}locked\t& !Door{door.color}open\t&"
  113. else:
  114. status = F"!Door{door.color}locked\t& !Door{door.color}open\t&"
  115. door_status.append(status)
  116. adv_status = []
  117. adv_positions = []
  118. for adversary in adversaries:
  119. status = ""
  120. position = ""
  121. if adversary.carrying:
  122. carrying = F"{adversary.name}_is_carrying_object\t"
  123. else:
  124. carrying = F"!{adversary.name}_is_carrying_object\t"
  125. status = F"{carrying}& !{adversary.name}Done\t& "
  126. position = F"x{adversary.name}={adversary.cur_pos[1]}\t& y{adversary.name}={adversary.cur_pos[0]}\t& view{adversary.name}={adversary.adversary_dir}"
  127. adv_status.append(status)
  128. adv_positions.append(position)
  129. door_status_text = ""
  130. if door_status:
  131. door_status_text = F"& {''.join(door_status)}\t"
  132. adv_status_text = ""
  133. if adv_status:
  134. adv_status_text = F"& {''.join(adv_status)}"
  135. adv_positions_text = ""
  136. if adv_positions:
  137. adv_positions_text = F"\t& {''.join(adv_positions)}"
  138. key_positions_text = ""
  139. if key_positions:
  140. key_positions_text = F"\t& {''.join(key_positions)}"
  141. move_text = ""
  142. if adversaries:
  143. move_text = F"move=0\t& "
  144. agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
  145. query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
  146. return query
  147. class ShieldingConfig(Enum):
  148. Training = 'training'
  149. Evaluation = 'evaluation'
  150. Disabled = 'none'
  151. Full = 'full'
  152. def __str__(self) -> str:
  153. return self.value
  154. def extract_keys(env):
  155. keys = []
  156. for j in range(env.grid.height):
  157. for i in range(env.grid.width):
  158. obj = env.grid.get(i,j)
  159. if obj and obj.type == "key":
  160. keys.append((obj, i, j))
  161. if env.carrying and env.carrying.type == "key":
  162. keys.append((env.carrying, -1, -1))
  163. # TODO Maybe need to add ordering of keys so it matches the order in the shield
  164. return keys
  165. def extract_doors(env):
  166. doors = []
  167. for j in range(env.grid.height):
  168. for i in range(env.grid.width):
  169. obj = env.grid.get(i,j)
  170. if obj and obj.type == "door":
  171. doors.append(obj)
  172. return doors
  173. def extract_adversaries(env):
  174. adv = []
  175. if not hasattr(env, "adversaries"):
  176. return []
  177. for color, adversary in env.adversaries.items():
  178. adv.append(adversary)
  179. return adv
  180. def create_log_dir(args):
  181. return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
  182. def test_name(args):
  183. return F"{args.expname}"
  184. def get_action_index_mapping(actions):
  185. for action_str in actions:
  186. if not "Agent" in action_str:
  187. continue
  188. if "move" in action_str:
  189. return Actions.forward
  190. elif "left" in action_str:
  191. return Actions.left
  192. elif "right" in action_str:
  193. return Actions.right
  194. elif "pickup" in action_str:
  195. return Actions.pickup
  196. elif "done" in action_str:
  197. return Actions.done
  198. elif "drop" in action_str:
  199. return Actions.drop
  200. elif "toggle" in action_str:
  201. return Actions.toggle
  202. elif "unlock" in action_str:
  203. return Actions.toggle
  204. raise ValueError("No action mapping found")
  205. def parse_arguments(argparse):
  206. parser = argparse.ArgumentParser()
  207. # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
  208. parser.add_argument("--env",
  209. help="gym environment to load",
  210. default="MiniGrid-LavaSlipperyS12-v2",
  211. choices=[
  212. "MiniGrid-Adv-8x8-v0",
  213. "MiniGrid-AdvSimple-8x8-v0",
  214. "MiniGrid-LavaCrossingS9N1-v0",
  215. "MiniGrid-LavaCrossingS9N3-v0",
  216. "MiniGrid-LavaSlipperyS12-v0",
  217. "MiniGrid-LavaSlipperyS12-v1",
  218. "MiniGrid-LavaSlipperyS12-v2",
  219. "MiniGrid-LavaSlipperyS12-v3",
  220. ])
  221. # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
  222. parser.add_argument("--grid_to_prism_binary_path", default="./main")
  223. parser.add_argument("--grid_path", default="grid")
  224. parser.add_argument("--prism_path", default="grid")
  225. parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
  226. parser.add_argument("--log_dir", default="../log_results/")
  227. parser.add_argument("--evaluations", type=int, default=30 )
  228. parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
  229. # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
  230. parser.add_argument("--workers", type=int, default=1)
  231. parser.add_argument("--num_gpus", type=float, default=0)
  232. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  233. parser.add_argument("--steps", default=20_000, type=int)
  234. parser.add_argument("--expname", default="exp")
  235. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  236. parser.add_argument("--prism_config", default=None)
  237. parser.add_argument("--shield_value", default=0.9, type=float)
  238. parser.add_argument("--probability_displacement", default=1/4, type=float)
  239. parser.add_argument("--probability_intended", default=3/4, type=float)
  240. parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
  241. # parser.add_argument("--random_starts", default=1, type=int)
  242. args = parser.parse_args()
  243. return args