You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

217 lines
8.5 KiB

2 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from minigrid.core.actions import Actions
  5. from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
  6. from gymnasium.spaces import Dict, Box
  7. from collections import deque
  8. from ray.rllib.utils.numpy import one_hot
  9. from utils import get_action_index_mapping, MiniGridShieldHandler, create_shield_query, ShieldingConfig
  10. class OneHotShieldingWrapper(gym.core.ObservationWrapper):
  11. def __init__(self, env, vector_index, framestack):
  12. super().__init__(env)
  13. self.framestack = framestack
  14. # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
  15. # +4: Direction.
  16. self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
  17. self.init_x = None
  18. self.init_y = None
  19. self.x_positions = []
  20. self.y_positions = []
  21. self.x_y_delta_buffer = deque(maxlen=100)
  22. self.vector_index = vector_index
  23. self.frame_buffer = deque(maxlen=self.framestack)
  24. for _ in range(self.framestack):
  25. self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
  26. self.observation_space = Dict(
  27. {
  28. "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
  29. "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
  30. }
  31. )
  32. def observation(self, obs):
  33. # Debug output: max-x/y positions to watch exploration progress.
  34. # print(F"Initial observation in Wrapper {obs}")
  35. if self.step_count == 0:
  36. for _ in range(self.framestack):
  37. self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
  38. if self.vector_index == 0:
  39. if self.x_positions:
  40. max_diff = max(
  41. np.sqrt(
  42. (np.array(self.x_positions) - self.init_x) ** 2
  43. + (np.array(self.y_positions) - self.init_y) ** 2
  44. )
  45. )
  46. self.x_y_delta_buffer.append(max_diff)
  47. print(
  48. "100-average dist travelled={}".format(
  49. np.mean(self.x_y_delta_buffer)
  50. )
  51. )
  52. self.x_positions = []
  53. self.y_positions = []
  54. self.init_x = self.agent_pos[0]
  55. self.init_y = self.agent_pos[1]
  56. self.x_positions.append(self.agent_pos[0])
  57. self.y_positions.append(self.agent_pos[1])
  58. image = obs["data"]
  59. # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
  60. objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
  61. colors = one_hot(image[:, :, 1], depth=len(COLORS))
  62. states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
  63. all_ = np.concatenate([objects, colors, states], -1)
  64. all_flat = np.reshape(all_, (-1,))
  65. direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
  66. single_frame = np.concatenate([all_flat, direction])
  67. self.frame_buffer.append(single_frame)
  68. tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
  69. return tmp
  70. class MiniGridShieldingWrapper(gym.core.Wrapper):
  71. def __init__(self,
  72. env,
  73. shield_creator : MiniGridShieldHandler,
  74. shield_query_creator,
  75. create_shield_at_reset=False,
  76. mask_actions=True):
  77. super(MiniGridShieldingWrapper, self).__init__(env)
  78. self.max_available_actions = env.action_space.n
  79. self.observation_space = Dict(
  80. {
  81. "data": env.observation_space.spaces["image"],
  82. "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
  83. }
  84. )
  85. self.shield_creator = shield_creator
  86. self.create_shield_at_reset = False # TODO
  87. self.shield = shield_creator.create_shield(env=self.env)
  88. self.mask_actions = mask_actions
  89. self.shield_query_creator = shield_query_creator
  90. print(F"Shielding is {self.mask_actions}")
  91. def create_action_mask(self):
  92. if not self.mask_actions:
  93. ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
  94. return ret
  95. cur_pos_str = self.shield_query_creator(self.env)
  96. # Create the mask
  97. # If shield restricts action mask only valid with 1.0
  98. # else set all actions as valid
  99. allowed_actions = []
  100. mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
  101. if cur_pos_str in self.shield and self.shield[cur_pos_str]:
  102. allowed_actions = self.shield[cur_pos_str]
  103. zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
  104. has_allowed_actions = False
  105. for allowed_action in allowed_actions:
  106. index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
  107. if index is None:
  108. assert(False)
  109. allowed = 1.0
  110. has_allowed_actions = True
  111. mask[index] = allowed
  112. else:
  113. for index, x in enumerate(mask):
  114. mask[index] = 1.0
  115. front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
  116. if front_tile is not None and front_tile.type == "key":
  117. mask[Actions.pickup] = 1.0
  118. if front_tile and front_tile.type == "door":
  119. mask[Actions.toggle] = 1.0
  120. # print(F"Mask is {mask} State: {cur_pos_str}")
  121. return mask
  122. def reset(self, *, seed=None, options=None):
  123. obs, infos = self.env.reset(seed=seed, options=options)
  124. if self.create_shield_at_reset and self.mask_actions:
  125. self.shield = self.shield_creator.create_shield(env=self.env)
  126. mask = self.create_action_mask()
  127. return {
  128. "data": obs["image"],
  129. "action_mask": mask
  130. }, infos
  131. def step(self, action):
  132. orig_obs, rew, done, truncated, info = self.env.step(action)
  133. mask = self.create_action_mask()
  134. obs = {
  135. "data": orig_obs["image"],
  136. "action_mask": mask,
  137. }
  138. return obs, rew, done, truncated, info
  139. def shielding_env_creater(config):
  140. name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
  141. framestack = config.get("framestack", 4)
  142. args = config.get("args", None)
  143. args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
  144. args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
  145. shielding = config.get("shielding", False)
  146. shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
  147. grid_to_prism_path=args.grid_to_prism_binary_path,
  148. prism_path=args.prism_path,
  149. formula=args.formula,
  150. shield_value=args.shield_value,
  151. prism_config=args.prism_config,
  152. shield_comparision=args.shield_comparision)
  153. probability_intended = args.probability_intended
  154. probability_displacement = args.probability_displacement
  155. probability_turn_intended = args.probability_turn_intended
  156. probability_turn_displacement = args.probability_turn_displacement
  157. env = gym.make(name,
  158. randomize_start=True,
  159. probability_intended=probability_intended,
  160. probability_displacement=probability_displacement,
  161. probability_turn_displacement=probability_turn_displacement,
  162. probability_turn_intended=probability_turn_intended)
  163. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
  164. env = OneHotShieldingWrapper(env,
  165. config.vector_index if hasattr(config, "vector_index") else 0,
  166. framestack=framestack
  167. )
  168. return env
  169. def register_minigrid_shielding_env(args):
  170. env_name = "mini-grid-shielding"
  171. register_env(env_name, shielding_env_creater)
  172. ModelCatalog.register_custom_model(
  173. "shielding_model",
  174. TorchActionMaskModel
  175. )