You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

165 lines
5.4 KiB

2 months ago
  1. from __future__ import annotations
  2. import numpy as np
  3. from minigrid.core.actions import Actions
  4. from minigrid.core.grid import Grid
  5. from minigrid.core.mission import MissionSpace
  6. from minigrid.core.world_object import Ball, Key, Wall
  7. from minigrid.minigrid_env import MiniGridEnv
  8. class MemoryEnv(MiniGridEnv):
  9. """
  10. ## Description
  11. This environment is a memory test. The agent starts in a small room where it
  12. sees an object. It then has to go through a narrow hallway which ends in a
  13. split. At each end of the split there is an object, one of which is the same
  14. as the object in the starting room. The agent has to remember the initial
  15. object, and go to the matching object at split.
  16. ## Mission Space
  17. "go to the matching object at the end of the hallway"
  18. ## Action Space
  19. | Num | Name | Action |
  20. |-----|--------------|---------------------------|
  21. | 0 | left | Turn left |
  22. | 1 | right | Turn right |
  23. | 2 | forward | Move forward |
  24. | 3 | pickup | Pick up an object |
  25. | 4 | drop | Unused |
  26. | 5 | toggle | Toggle/activate an object |
  27. | 6 | done | Unused |
  28. ## Observation Encoding
  29. - Each tile is encoded as a 3 dimensional tuple:
  30. `(OBJECT_IDX, COLOR_IDX, STATE)`
  31. - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
  32. [minigrid/minigrid.py](minigrid/minigrid.py)
  33. - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
  34. ## Rewards
  35. A reward of '1 - 0.9 * (step_count / max_steps)' is given for success, and '0' for failure.
  36. ## Termination
  37. The episode ends if any one of the following conditions is met:
  38. 1. The agent reaches the correct matching object.
  39. 2. The agent reaches the wrong matching object.
  40. 3. Timeout (see `max_steps`).
  41. ## Registered Configurations
  42. S: size of map SxS.
  43. - `MiniGrid-MemoryS17Random-v0`
  44. - `MiniGrid-MemoryS13Random-v0`
  45. - `MiniGrid-MemoryS13-v0`
  46. - `MiniGrid-MemoryS11-v0`
  47. """
  48. def __init__(
  49. self, size=8, random_length=False, max_steps: int | None = None, **kwargs
  50. ):
  51. self.size = size
  52. self.random_length = random_length
  53. if max_steps is None:
  54. max_steps = 5 * size**2
  55. mission_space = MissionSpace(mission_func=self._gen_mission)
  56. super().__init__(
  57. mission_space=mission_space,
  58. width=size,
  59. height=size,
  60. # Set this to True for maximum speed
  61. see_through_walls=False,
  62. max_steps=max_steps,
  63. **kwargs,
  64. )
  65. @staticmethod
  66. def _gen_mission():
  67. return "go to the matching object at the end of the hallway"
  68. def _gen_grid(self, width, height):
  69. self.grid = Grid(width, height)
  70. # Generate the surrounding walls
  71. self.grid.horz_wall(0, 0)
  72. self.grid.horz_wall(0, height - 1)
  73. self.grid.vert_wall(0, 0)
  74. self.grid.vert_wall(width - 1, 0)
  75. assert height % 2 == 1
  76. upper_room_wall = height // 2 - 2
  77. lower_room_wall = height // 2 + 2
  78. if self.random_length:
  79. hallway_end = self._rand_int(4, width - 2)
  80. else:
  81. hallway_end = width - 3
  82. # Start room
  83. for i in range(1, 5):
  84. self.grid.set(i, upper_room_wall, Wall())
  85. self.grid.set(i, lower_room_wall, Wall())
  86. self.grid.set(4, upper_room_wall + 1, Wall())
  87. self.grid.set(4, lower_room_wall - 1, Wall())
  88. # Horizontal hallway
  89. for i in range(5, hallway_end):
  90. self.grid.set(i, upper_room_wall + 1, Wall())
  91. self.grid.set(i, lower_room_wall - 1, Wall())
  92. # Vertical hallway
  93. for j in range(0, height):
  94. if j != height // 2:
  95. self.grid.set(hallway_end, j, Wall())
  96. self.grid.set(hallway_end + 2, j, Wall())
  97. # Fix the player's start position and orientation
  98. self.agent_pos = np.array((self._rand_int(1, hallway_end + 1), height // 2))
  99. self.agent_dir = 0
  100. # Place objects
  101. start_room_obj = self._rand_elem([Key, Ball])
  102. self.grid.set(1, height // 2 - 1, start_room_obj("green"))
  103. other_objs = self._rand_elem([[Ball, Key], [Key, Ball]])
  104. pos0 = (hallway_end + 1, height // 2 - 2)
  105. pos1 = (hallway_end + 1, height // 2 + 2)
  106. self.grid.set(*pos0, other_objs[0]("green"))
  107. self.grid.set(*pos1, other_objs[1]("green"))
  108. # Choose the target objects
  109. if start_room_obj == other_objs[0]:
  110. self.success_pos = (pos0[0], pos0[1] + 1)
  111. self.failure_pos = (pos1[0], pos1[1] - 1)
  112. else:
  113. self.success_pos = (pos1[0], pos1[1] - 1)
  114. self.failure_pos = (pos0[0], pos0[1] + 1)
  115. self.mission = "go to the matching object at the end of the hallway"
  116. def step(self, action):
  117. if action == Actions.pickup:
  118. action = Actions.toggle
  119. obs, reward, terminated, truncated, info = super().step(action)
  120. if tuple(self.agent_pos) == self.success_pos:
  121. reward = self._reward()
  122. terminated = True
  123. if tuple(self.agent_pos) == self.failure_pos:
  124. reward = 0
  125. terminated = True
  126. return obs, reward, terminated, truncated, info