The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

127 lines
4.1 KiB

4 months ago
  1. from __future__ import annotations
  2. from minigrid.core.grid import Grid
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.world_object import Door
  5. from minigrid.minigrid_env import MiniGridEnv
  6. class RedBlueDoorEnv(MiniGridEnv):
  7. """
  8. ## Description
  9. The agent is randomly placed within a room with one red and one blue door
  10. facing opposite directions. The agent has to open the red door and then open
  11. the blue door, in that order. Note that, surprisingly, this environment is
  12. solvable without memory.
  13. ## Mission Space
  14. "open the red door then the blue door"
  15. ## Action Space
  16. | Num | Name | Action |
  17. |-----|--------------|---------------------------|
  18. | 0 | left | Turn left |
  19. | 1 | right | Turn right |
  20. | 2 | forward | Move forward |
  21. | 3 | pickup | Unused |
  22. | 4 | drop | Unused |
  23. | 5 | toggle | Toggle/activate an object |
  24. | 6 | done | Unused |
  25. ## Observation Encoding
  26. - Each tile is encoded as a 3 dimensional tuple:
  27. `(OBJECT_IDX, COLOR_IDX, STATE)`
  28. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  29. [minigrid/minigrid.py](minigrid/minigrid.py)
  30. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  31. ## Rewards
  32. A reward of '1 - 0.9 * (step_count / max_steps)' is given for success, and '0' for failure.
  33. ## Termination
  34. The episode ends if any one of the following conditions is met:
  35. 1. The agent opens the blue door having already opened the red door.
  36. 2. The agent opens the blue door without having opened the red door yet.
  37. 3. Timeout (see `max_steps`).
  38. ## Registered Configurations
  39. - `MiniGrid-RedBlueDoors-6x6-v0`
  40. - `MiniGrid-RedBlueDoors-8x8-v0`
  41. """
  42. def __init__(self, size=8, max_steps: int | None = None, **kwargs):
  43. self.size = size
  44. mission_space = MissionSpace(mission_func=self._gen_mission)
  45. if max_steps is None:
  46. max_steps = 20 * size**2
  47. super().__init__(
  48. mission_space=mission_space,
  49. width=2 * size,
  50. height=size,
  51. max_steps=max_steps,
  52. **kwargs,
  53. )
  54. @staticmethod
  55. def _gen_mission():
  56. return "open the red door then the blue door"
  57. def _gen_grid(self, width, height):
  58. # Create an empty grid
  59. self.grid = Grid(width, height)
  60. # Generate the grid walls
  61. self.grid.wall_rect(0, 0, 2 * self.size, self.size)
  62. self.grid.wall_rect(self.size // 2, 0, self.size, self.size)
  63. # Place the agent in the top-left corner
  64. self.place_agent(top=(self.size // 2, 0), size=(self.size, self.size))
  65. # Add a red door at a random position in the left wall
  66. pos = self._rand_int(1, self.size - 1)
  67. self.red_door = Door("red")
  68. self.grid.set(self.size // 2, pos, self.red_door)
  69. # Add a blue door at a random position in the right wall
  70. pos = self._rand_int(1, self.size - 1)
  71. self.blue_door = Door("blue")
  72. self.grid.set(self.size // 2 + self.size - 1, pos, self.blue_door)
  73. # Generate the mission string
  74. self.mission = "open the red door then the blue door"
  75. def step(self, action):
  76. red_door_opened_before = self.red_door.is_open
  77. blue_door_opened_before = self.blue_door.is_open
  78. obs, reward, terminated, truncated, info = super().step(action)
  79. red_door_opened_after = self.red_door.is_open
  80. blue_door_opened_after = self.blue_door.is_open
  81. if blue_door_opened_after:
  82. if red_door_opened_before:
  83. reward = self._reward()
  84. terminated = True
  85. else:
  86. reward = 0
  87. terminated = True
  88. elif red_door_opened_after:
  89. if blue_door_opened_before:
  90. reward = 0
  91. terminated = True
  92. return obs, reward, terminated, truncated, info