You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

271 lines
8.5 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.roomgrid import RoomGrid
  5. from minigrid.core.world_object import Ball, Box, Key
  6. class ObstructedMazeEnv(RoomGrid):
  7. """
  8. ## Description
  9. The agent has to pick up a box which is placed in a corner of a 3x3 maze.
  10. The doors are locked, the keys are hidden in boxes and doors are obstructed
  11. by balls. This environment can be solved without relying on language.
  12. ## Mission Space
  13. "pick up the {COLOR_NAMES[0]} ball"
  14. ## Action Space
  15. | Num | Name | Action |
  16. |-----|--------------|---------------------------|
  17. | 0 | left | Turn left |
  18. | 1 | right | Turn right |
  19. | 2 | forward | Move forward |
  20. | 3 | pickup | Pick up an object |
  21. | 4 | drop | Unused |
  22. | 5 | toggle | Toggle/activate an object |
  23. | 6 | done | Unused |
  24. ## Observation Encoding
  25. - Each tile is encoded as a 3 dimensional tuple:
  26. `(OBJECT_IDX, COLOR_IDX, STATE)`
  27. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  28. [minigrid/minigrid.py](minigrid/minigrid.py)
  29. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  30. ## Rewards
  31. A reward of '1 - 0.9 * (step_count / max_steps)' is given for success, and '0' for failure.
  32. ## Termination
  33. The episode ends if any one of the following conditions is met:
  34. 1. The agent picks up the blue ball.
  35. 2. Timeout (see `max_steps`).
  36. ## Registered Configurations
  37. "NDl" are the number of doors locked.
  38. "h" if the key is hidden in a box.
  39. "b" if the door is obstructed by a ball.
  40. "Q" number of quarters that will have doors and keys out of the 9 that the
  41. map already has.
  42. "Full" 3x3 maze with "h" and "b" options.
  43. "v1" prevents the key from being covered by the blocking ball. Only 2Dlhb, 1Q, 2Q, and Full are
  44. updated to v1. Other configurations won't face this issue because there is no blocking ball (1Dl,
  45. 1Dlh, 2Dl, 2Dlh) or the only blocking ball is added before the key (1Dlhb).
  46. - `MiniGrid-ObstructedMaze-1Dl-v0`
  47. - `MiniGrid-ObstructedMaze-1Dlh-v0`
  48. - `MiniGrid-ObstructedMaze-1Dlhb-v0`
  49. - `MiniGrid-ObstructedMaze-2Dl-v0`
  50. - `MiniGrid-ObstructedMaze-2Dlh-v0`
  51. - `MiniGrid-ObstructedMaze-2Dlhb-v0`
  52. - `MiniGrid-ObstructedMaze-2Dlhb-v1`
  53. - `MiniGrid-ObstructedMaze-1Q-v0`
  54. - `MiniGrid-ObstructedMaze-1Q-v1`
  55. - `MiniGrid-ObstructedMaze-2Q-v0`
  56. - `MiniGrid-ObstructedMaze-2Q-v1`
  57. - `MiniGrid-ObstructedMaze-Full-v0`
  58. - `MiniGrid-ObstructedMaze-Full-v1`
  59. """
  60. def __init__(
  61. self,
  62. num_rows,
  63. num_cols,
  64. num_rooms_visited,
  65. max_steps: int | None = None,
  66. **kwargs,
  67. ):
  68. room_size = 6
  69. if max_steps is None:
  70. max_steps = 4 * num_rooms_visited * room_size**2
  71. mission_space = MissionSpace(
  72. mission_func=self._gen_mission,
  73. ordered_placeholders=[[COLOR_NAMES[0]]],
  74. )
  75. super().__init__(
  76. mission_space=mission_space,
  77. room_size=room_size,
  78. num_rows=num_rows,
  79. num_cols=num_cols,
  80. max_steps=max_steps,
  81. **kwargs,
  82. )
  83. self.obj = Ball() # initialize the obj attribute, that will be changed later on
  84. @staticmethod
  85. def _gen_mission(color: str):
  86. return f"pick up the {color} ball"
  87. def _gen_grid(self, width, height):
  88. super()._gen_grid(width, height)
  89. # Define all possible colors for doors
  90. self.door_colors = self._rand_subset(COLOR_NAMES, len(COLOR_NAMES))
  91. # Define the color of the ball to pick up
  92. self.ball_to_find_color = COLOR_NAMES[0]
  93. # Define the color of the balls that obstruct doors
  94. self.blocking_ball_color = COLOR_NAMES[1]
  95. # Define the color of boxes in which keys are hidden
  96. self.box_color = COLOR_NAMES[2]
  97. self.mission = "pick up the %s ball" % self.ball_to_find_color
  98. def step(self, action):
  99. obs, reward, terminated, truncated, info = super().step(action)
  100. if action == self.actions.pickup:
  101. if self.carrying and self.carrying == self.obj:
  102. reward = self._reward()
  103. terminated = True
  104. return obs, reward, terminated, truncated, info
  105. def add_door(
  106. self,
  107. i,
  108. j,
  109. door_idx=0,
  110. color=None,
  111. locked=False,
  112. key_in_box=False,
  113. blocked=False,
  114. ):
  115. """
  116. Add a door. If the door must be locked, it also adds the key.
  117. If the key must be hidden, it is put in a box. If the door must
  118. be obstructed, it adds a ball in front of the door.
  119. """
  120. door, door_pos = super().add_door(i, j, door_idx, color, locked=locked)
  121. if blocked:
  122. vec = DIR_TO_VEC[door_idx]
  123. blocking_ball = Ball(self.blocking_ball_color) if blocked else None
  124. self.grid.set(door_pos[0] - vec[0], door_pos[1] - vec[1], blocking_ball)
  125. if locked:
  126. obj = Key(door.color)
  127. if key_in_box:
  128. box = Box(self.box_color)
  129. box.contains = obj
  130. obj = box
  131. self.place_in_room(i, j, obj)
  132. return door, door_pos
  133. class ObstructedMaze_1Dlhb(ObstructedMazeEnv):
  134. """
  135. A blue ball is hidden in a 2x1 maze. A locked door separates
  136. rooms. Doors are obstructed by a ball and keys are hidden in boxes.
  137. """
  138. def __init__(self, key_in_box=True, blocked=True, **kwargs):
  139. self.key_in_box = key_in_box
  140. self.blocked = blocked
  141. super().__init__(num_rows=1, num_cols=2, num_rooms_visited=2, **kwargs)
  142. def _gen_grid(self, width, height):
  143. super()._gen_grid(width, height)
  144. self.add_door(
  145. 0,
  146. 0,
  147. door_idx=0,
  148. color=self.door_colors[0],
  149. locked=True,
  150. key_in_box=self.key_in_box,
  151. blocked=self.blocked,
  152. )
  153. self.obj, _ = self.add_object(1, 0, "ball", color=self.ball_to_find_color)
  154. self.place_agent(0, 0)
  155. class ObstructedMaze_Full(ObstructedMazeEnv):
  156. """
  157. A blue ball is hidden in one of the 4 corners of a 3x3 maze. Doors
  158. are locked, doors are obstructed by a ball and keys are hidden in
  159. boxes.
  160. """
  161. def __init__(
  162. self,
  163. agent_room=(1, 1),
  164. key_in_box=True,
  165. blocked=True,
  166. num_quarters=4,
  167. num_rooms_visited=25,
  168. **kwargs,
  169. ):
  170. self.agent_room = agent_room
  171. self.key_in_box = key_in_box
  172. self.blocked = blocked
  173. self.num_quarters = num_quarters
  174. super().__init__(
  175. num_rows=3, num_cols=3, num_rooms_visited=num_rooms_visited, **kwargs
  176. )
  177. def _gen_grid(self, width, height):
  178. super()._gen_grid(width, height)
  179. middle_room = (1, 1)
  180. # Define positions of "side rooms" i.e. rooms that are neither
  181. # corners nor the center.
  182. side_rooms = [(2, 1), (1, 2), (0, 1), (1, 0)][: self.num_quarters]
  183. for i in range(len(side_rooms)):
  184. side_room = side_rooms[i]
  185. # Add a door between the center room and the side room
  186. self.add_door(
  187. *middle_room, door_idx=i, color=self.door_colors[i], locked=False
  188. )
  189. for k in [-1, 1]:
  190. # Add a door to each side of the side room
  191. self.add_door(
  192. *side_room,
  193. locked=True,
  194. door_idx=(i + k) % 4,
  195. color=self.door_colors[(i + k) % len(self.door_colors)],
  196. key_in_box=self.key_in_box,
  197. blocked=self.blocked,
  198. )
  199. corners = [(2, 0), (2, 2), (0, 2), (0, 0)][: self.num_quarters]
  200. ball_room = self._rand_elem(corners)
  201. self.obj, _ = self.add_object(
  202. ball_room[0], ball_room[1], "ball", color=self.ball_to_find_color
  203. )
  204. self.place_agent(*self.agent_room)
  205. class ObstructedMaze_2Dl(ObstructedMaze_Full):
  206. def __init__(self, **kwargs):
  207. super().__init__((2, 1), False, False, 1, 4, **kwargs)
  208. class ObstructedMaze_2Dlh(ObstructedMaze_Full):
  209. def __init__(self, **kwargs):
  210. super().__init__((2, 1), True, False, 1, 4, **kwargs)
  211. class ObstructedMaze_2Dlhb(ObstructedMaze_Full):
  212. def __init__(self, **kwargs):
  213. super().__init__((2, 1), True, True, 1, 4, **kwargs)