You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

318 lines
11 KiB

2 months ago
  1. import stormpy
  2. import stormpy.core
  3. import stormpy.simulator
  4. import stormpy.shields
  5. import stormpy.logic
  6. import stormpy.examples
  7. import stormpy.examples.files
  8. from enum import Enum
  9. from abc import ABC
  10. from minigrid.core.actions import Actions
  11. import os
  12. import time
  13. class Action():
  14. def __init__(self, idx, prob=1, labels=[]) -> None:
  15. self.idx = idx
  16. self.prob = prob
  17. self.labels = labels
  18. class ShieldHandler(ABC):
  19. def __init__(self) -> None:
  20. pass
  21. def create_shield(self, **kwargs) -> dict:
  22. pass
  23. class MiniGridShieldHandler(ShieldHandler):
  24. def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None, shield_comparision='relative') -> None:
  25. self.grid_file = grid_file
  26. self.grid_to_prism_path = grid_to_prism_path
  27. self.prism_path = prism_path
  28. self.formula = formula
  29. self.prism_config = prism_config
  30. self.shield_value = shield_value
  31. self.shield_comparision = shield_comparision
  32. def __export_grid_to_text(self, env):
  33. f = open(self.grid_file, "w")
  34. f.write(env.printGrid(init=True))
  35. f.close()
  36. def __create_prism(self):
  37. # result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
  38. if self.prism_config is None:
  39. result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}")
  40. else:
  41. result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
  42. # result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
  43. assert result == 0, "Prism file could not be generated"
  44. f = open(self.prism_path, "a")
  45. f.close()
  46. def __create_shield_dict(self):
  47. print(self.prism_path)
  48. program = stormpy.parse_prism_program(self.prism_path)
  49. shield_comp = stormpy.logic.ShieldComparison.RELATIVE
  50. if self.shield_comparision == 'absolute':
  51. shield_comp = stormpy.logic.ShieldComparison.ABSOLUTE
  52. shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, shield_comp, self.shield_value)
  53. formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
  54. options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
  55. options.set_build_state_valuations(True)
  56. options.set_build_choice_labels(True)
  57. options.set_build_all_labels()
  58. model = stormpy.build_sparse_model_with_options(program, options)
  59. result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
  60. assert result.has_shield
  61. shield = result.shield
  62. action_dictionary = {}
  63. shield_scheduler = shield.construct()
  64. state_valuations = model.state_valuations
  65. choice_labeling = model.choice_labeling
  66. stormpy.shields.export_shield(model, shield, "myshield")
  67. for stateID in model.states:
  68. choice = shield_scheduler.get_choice(stateID)
  69. choices = choice.choice_map
  70. state_valuation = state_valuations.get_string(stateID)
  71. actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
  72. action_dictionary[state_valuation] = actions_to_be_executed
  73. return action_dictionary
  74. def create_shield(self, **kwargs):
  75. env = kwargs["env"]
  76. self.__export_grid_to_text(env)
  77. self.__create_prism()
  78. return self.__create_shield_dict()
  79. def create_shield_query(env):
  80. coordinates = env.env.agent_pos
  81. view_direction = env.env.agent_dir
  82. keys = extract_keys(env)
  83. doors = extract_doors(env)
  84. adversaries = extract_adversaries(env)
  85. if env.carrying:
  86. agent_carrying = F"Agent_is_carrying_object\t"
  87. else:
  88. agent_carrying = "!Agent_is_carrying_object\t"
  89. key_positions = []
  90. agent_key_status = []
  91. for key in keys:
  92. key_color = key[0].color
  93. key_x = key[1]
  94. key_y = key[2]
  95. if env.carrying and env.carrying.type == "key":
  96. agent_key_text = F"Agent_has_{env.carrying.color}_key\t& "
  97. key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
  98. else:
  99. agent_key_text = F"!Agent_has_{key_color}_key\t& "
  100. key_position = F"xKey{key_color}={key_x}\t& yKey{key_color}={key_y}\t"
  101. key_positions.append(key_position)
  102. agent_key_status.append(agent_key_text)
  103. if key_positions:
  104. key_positions[-1] = key_positions[-1].strip()
  105. door_status = []
  106. for door in doors:
  107. status = ""
  108. if door.is_open:
  109. status = F"!Door{door.color}locked\t& Door{door.color}open\t&"
  110. elif door.is_locked:
  111. status = F"Door{door.color}locked\t& !Door{door.color}open\t&"
  112. else:
  113. status = F"!Door{door.color}locked\t& !Door{door.color}open\t&"
  114. door_status.append(status)
  115. adv_status = []
  116. adv_positions = []
  117. for adversary in adversaries:
  118. status = ""
  119. position = ""
  120. if adversary.carrying:
  121. carrying = F"{adversary.name}_is_carrying_object\t"
  122. else:
  123. carrying = F"!{adversary.name}_is_carrying_object\t"
  124. status = F"{carrying}& !{adversary.name}Done\t& "
  125. position = F"x{adversary.name}={adversary.cur_pos[1]}\t& y{adversary.name}={adversary.cur_pos[0]}\t& view{adversary.name}={adversary.adversary_dir}"
  126. adv_status.append(status)
  127. adv_positions.append(position)
  128. door_status_text = ""
  129. if door_status:
  130. door_status_text = F"& {''.join(door_status)}\t"
  131. adv_status_text = ""
  132. if adv_status:
  133. adv_status_text = F"& {''.join(adv_status)}"
  134. adv_positions_text = ""
  135. if adv_positions:
  136. adv_positions_text = F"\t& {''.join(adv_positions)}"
  137. key_positions_text = ""
  138. if key_positions:
  139. key_positions_text = F"\t& {''.join(key_positions)}"
  140. move_text = ""
  141. if adversaries:
  142. move_text = F"move=0\t& "
  143. agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
  144. query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
  145. return query
  146. class ShieldingConfig(Enum):
  147. Training = 'training'
  148. Evaluation = 'evaluation'
  149. Disabled = 'none'
  150. Full = 'full'
  151. def __str__(self) -> str:
  152. return self.value
  153. def extract_keys(env):
  154. keys = []
  155. for j in range(env.grid.height):
  156. for i in range(env.grid.width):
  157. obj = env.grid.get(i,j)
  158. if obj and obj.type == "key":
  159. keys.append((obj, i, j))
  160. if env.carrying and env.carrying.type == "key":
  161. keys.append((env.carrying, -1, -1))
  162. # TODO Maybe need to add ordering of keys so it matches the order in the shield
  163. return keys
  164. def extract_doors(env):
  165. doors = []
  166. for j in range(env.grid.height):
  167. for i in range(env.grid.width):
  168. obj = env.grid.get(i,j)
  169. if obj and obj.type == "door":
  170. doors.append(obj)
  171. return doors
  172. def extract_adversaries(env):
  173. adv = []
  174. if not hasattr(env, "adversaries") or not env.adversaries:
  175. return []
  176. for color, adversary in env.adversaries.items():
  177. adv.append(adversary)
  178. return adv
  179. def create_log_dir(args):
  180. return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
  181. def test_name(args):
  182. return F"{args.expname}"
  183. def get_action_index_mapping(actions):
  184. for action_str in actions:
  185. if not "Agent" in action_str:
  186. continue
  187. if "move" in action_str:
  188. return Actions.forward
  189. elif "left" in action_str:
  190. return Actions.left
  191. elif "right" in action_str:
  192. return Actions.right
  193. elif "pickup" in action_str:
  194. return Actions.pickup
  195. elif "done" in action_str:
  196. return Actions.done
  197. elif "drop" in action_str:
  198. return Actions.drop
  199. elif "toggle" in action_str:
  200. return Actions.toggle
  201. elif "unlock" in action_str:
  202. return Actions.toggle
  203. raise ValueError("No action mapping found")
  204. def parse_arguments(argparse):
  205. parser = argparse.ArgumentParser()
  206. # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
  207. parser.add_argument("--env",
  208. help="gym environment to load",
  209. default="MiniGrid-LavaSlipperyCliffS12-v2",
  210. choices=[
  211. "MiniGrid-Adv-8x8-v0",
  212. "MiniGrid-AdvSimple-8x8-v0",
  213. "MiniGrid-LavaCrossingS9N1-v0",
  214. "MiniGrid-LavaCrossingS9N3-v0",
  215. "MiniGrid-LavaSlipperyCliffS12-v0",
  216. "MiniGrid-LavaFaultyS12-30-v0",
  217. ])
  218. # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
  219. parser.add_argument("--grid_to_prism_binary_path", default="./main")
  220. parser.add_argument("--grid_path", default="grid")
  221. parser.add_argument("--prism_path", default="grid")
  222. parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
  223. parser.add_argument("--log_dir", default="../log_results/")
  224. parser.add_argument("--evaluations", type=int, default=30 )
  225. parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
  226. # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
  227. parser.add_argument("--workers", type=int, default=1)
  228. parser.add_argument("--num_gpus", type=float, default=0)
  229. parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
  230. parser.add_argument("--steps", default=20_000, type=int)
  231. parser.add_argument("--expname", default="exp")
  232. parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
  233. parser.add_argument("--prism_config", default=None)
  234. parser.add_argument("--shield_value", default=0.9, type=float)
  235. parser.add_argument("--probability_displacement", default=1/4, type=float)
  236. parser.add_argument("--probability_intended", default=3/4, type=float)
  237. parser.add_argument("--probability_turn_displacement", default=0/4, type=float)
  238. parser.add_argument("--probability_turn_intended", default=4/4, type=float)
  239. parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
  240. # parser.add_argument("--random_starts", default=1, type=int)
  241. args = parser.parse_args()
  242. return args