You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

252 lines
9.6 KiB

12 months ago
12 months ago
12 months ago
12 months ago
12 months ago
12 months ago
12 months ago
12 months ago
12 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from minigrid.core.actions import Actions
  5. from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
  6. from gymnasium.spaces import Dict, Box
  7. from collections import deque
  8. from ray.rllib.utils.numpy import one_hot
  9. from helpers import get_action_index_mapping
  10. from shieldhandlers import ShieldHandler
  11. class OneHotShieldingWrapper(gym.core.ObservationWrapper):
  12. def __init__(self, env, vector_index, framestack):
  13. super().__init__(env)
  14. self.framestack = framestack
  15. # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
  16. # +4: Direction.
  17. self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
  18. self.init_x = None
  19. self.init_y = None
  20. self.x_positions = []
  21. self.y_positions = []
  22. self.x_y_delta_buffer = deque(maxlen=100)
  23. self.vector_index = vector_index
  24. self.frame_buffer = deque(maxlen=self.framestack)
  25. for _ in range(self.framestack):
  26. self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
  27. self.observation_space = Dict(
  28. {
  29. "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
  30. "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
  31. }
  32. )
  33. def observation(self, obs):
  34. # Debug output: max-x/y positions to watch exploration progress.
  35. # print(F"Initial observation in Wrapper {obs}")
  36. if self.step_count == 0:
  37. for _ in range(self.framestack):
  38. self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
  39. if self.vector_index == 0:
  40. if self.x_positions:
  41. max_diff = max(
  42. np.sqrt(
  43. (np.array(self.x_positions) - self.init_x) ** 2
  44. + (np.array(self.y_positions) - self.init_y) ** 2
  45. )
  46. )
  47. self.x_y_delta_buffer.append(max_diff)
  48. print(
  49. "100-average dist travelled={}".format(
  50. np.mean(self.x_y_delta_buffer)
  51. )
  52. )
  53. self.x_positions = []
  54. self.y_positions = []
  55. self.init_x = self.agent_pos[0]
  56. self.init_y = self.agent_pos[1]
  57. self.x_positions.append(self.agent_pos[0])
  58. self.y_positions.append(self.agent_pos[1])
  59. image = obs["data"]
  60. # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
  61. objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
  62. colors = one_hot(image[:, :, 1], depth=len(COLORS))
  63. states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
  64. all_ = np.concatenate([objects, colors, states], -1)
  65. all_flat = np.reshape(all_, (-1,))
  66. direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
  67. single_frame = np.concatenate([all_flat, direction])
  68. self.frame_buffer.append(single_frame)
  69. tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
  70. return tmp
  71. class MiniGridShieldingWrapper(gym.core.Wrapper):
  72. def __init__(self,
  73. env,
  74. shield_creator : ShieldHandler,
  75. shield_query_creator,
  76. create_shield_at_reset=True,
  77. mask_actions=True):
  78. super(MiniGridShieldingWrapper, self).__init__(env)
  79. self.max_available_actions = env.action_space.n
  80. self.observation_space = Dict(
  81. {
  82. "data": env.observation_space.spaces["image"],
  83. "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
  84. }
  85. )
  86. self.shield_creator = shield_creator
  87. self.create_shield_at_reset = create_shield_at_reset
  88. self.shield = shield_creator.create_shield(env=self.env)
  89. self.mask_actions = mask_actions
  90. self.shield_query_creator = shield_query_creator
  91. print(F"Shielding is {self.mask_actions}")
  92. def create_action_mask(self):
  93. # print(F"{self.mask_actions} No shielding")
  94. if not self.mask_actions:
  95. # print("No shielding")
  96. ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
  97. # print(ret)
  98. return ret
  99. cur_pos_str = self.shield_query_creator(self.env)
  100. # print(F"Pos string {cur_pos_str}")
  101. # print(F"Shield {list(self.shield.keys())[0]}")
  102. # print(F"Is pos str in shield: {cur_pos_str in self.shield}, Position Str {cur_pos_str}")
  103. # Create the mask
  104. # If shield restricts action mask only valid with 1.0
  105. # else set all actions as valid
  106. allowed_actions = []
  107. mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
  108. if cur_pos_str in self.shield and self.shield[cur_pos_str]:
  109. allowed_actions = self.shield[cur_pos_str]
  110. zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
  111. has_allowed_actions = False
  112. for allowed_action in allowed_actions:
  113. index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
  114. if index is None:
  115. print(F"No mapping for action {list(allowed_action.labels)}")
  116. print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}")
  117. assert(False)
  118. allowed = 1.0 # random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
  119. if allowed_action.prob == 0 and allowed:
  120. assert False
  121. if allowed:
  122. has_allowed_actions = True
  123. mask[index] = allowed
  124. # if not has_allowed_actions:
  125. # print(F"No action allowed for pos string {cur_pos_str}")
  126. # assert(False)
  127. else:
  128. for index, x in enumerate(mask):
  129. mask[index] = 1.0
  130. front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
  131. if front_tile is not None and front_tile.type == "key":
  132. mask[Actions.pickup] = 1.0
  133. if front_tile and front_tile.type == "door":
  134. mask[Actions.toggle] = 1.0
  135. # print(F"Mask is {mask} State: {cur_pos_str}")
  136. return mask
  137. def reset(self, *, seed=None, options=None):
  138. obs, infos = self.env.reset(seed=seed, options=options)
  139. if self.create_shield_at_reset and self.mask_actions:
  140. self.shield = self.shield_creator.create_shield(env=self.env)
  141. mask = self.create_action_mask()
  142. return {
  143. "data": obs["image"],
  144. "action_mask": mask
  145. }, infos
  146. def step(self, action):
  147. orig_obs, rew, done, truncated, info = self.env.step(action)
  148. mask = self.create_action_mask()
  149. obs = {
  150. "data": orig_obs["image"],
  151. "action_mask": mask,
  152. }
  153. return obs, rew, done, truncated, info
  154. class MiniGridSbShieldingWrapper(gym.core.Wrapper):
  155. def __init__(self,
  156. env,
  157. shield_creator : ShieldHandler,
  158. shield_query_creator,
  159. create_shield_at_reset = True,
  160. mask_actions=True,
  161. ):
  162. super(MiniGridSbShieldingWrapper, self).__init__(env)
  163. self.max_available_actions = env.action_space.n
  164. self.observation_space = env.observation_space.spaces["image"]
  165. self.shield_creator = shield_creator
  166. self.mask_actions = mask_actions
  167. self.shield_query_creator = shield_query_creator
  168. print(F"Shielding is {self.mask_actions}")
  169. def create_action_mask(self):
  170. if not self.mask_actions:
  171. return np.array([1.0] * self.max_available_actions, dtype=np.int8)
  172. cur_pos_str = self.shield_query_creator(self.env)
  173. allowed_actions = []
  174. # Create the mask
  175. # If shield restricts actions, mask only valid actions with 1.0
  176. # else set all actions valid
  177. mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
  178. if cur_pos_str in self.shield and self.shield[cur_pos_str]:
  179. allowed_actions = self.shield[cur_pos_str]
  180. for allowed_action in allowed_actions:
  181. index = get_action_index_mapping(allowed_action.labels)
  182. if index is None:
  183. assert(False)
  184. mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
  185. else:
  186. for index, x in enumerate(mask):
  187. mask[index] = 1.0
  188. front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
  189. if front_tile and front_tile.type == "door":
  190. mask[Actions.toggle] = 1.0
  191. return mask
  192. def reset(self, *, seed=None, options=None):
  193. obs, infos = self.env.reset(seed=seed, options=options)
  194. shield = self.shield_creator.create_shield(env=self.env)
  195. self.shield = shield
  196. return obs["image"], infos
  197. def step(self, action):
  198. orig_obs, rew, done, truncated, info = self.env.step(action)
  199. obs = orig_obs["image"]
  200. return obs, rew, done, truncated, info