You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

91 lines
2.8 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.constants import COLOR_NAMES
  3. from minigrid.core.grid import Grid
  4. from minigrid.core.mission import MissionSpace
  5. from minigrid.core.world_object import Ball, Box, Door, Key
  6. from minigrid.minigrid_env import MiniGridEnv
  7. class PlaygroundEnv(MiniGridEnv):
  8. """
  9. Environment with multiple rooms and random objects.
  10. This environment has no specific goals or rewards.
  11. """
  12. def __init__(self, max_steps=100, **kwargs):
  13. mission_space = MissionSpace(mission_func=self._gen_mission)
  14. self.size = 19
  15. super().__init__(
  16. mission_space=mission_space,
  17. width=self.size,
  18. height=self.size,
  19. max_steps=max_steps,
  20. **kwargs,
  21. )
  22. @staticmethod
  23. def _gen_mission():
  24. return ""
  25. def _gen_grid(self, width, height):
  26. # Create the grid
  27. self.grid = Grid(width, height)
  28. # Generate the surrounding walls
  29. self.grid.horz_wall(0, 0)
  30. self.grid.horz_wall(0, height - 1)
  31. self.grid.vert_wall(0, 0)
  32. self.grid.vert_wall(width - 1, 0)
  33. roomW = width // 3
  34. roomH = height // 3
  35. # For each row of rooms
  36. for j in range(0, 3):
  37. # For each column
  38. for i in range(0, 3):
  39. xL = i * roomW
  40. yT = j * roomH
  41. xR = xL + roomW
  42. yB = yT + roomH
  43. # Bottom wall and door
  44. if i + 1 < 3:
  45. self.grid.vert_wall(xR, yT, roomH)
  46. pos = (xR, self._rand_int(yT + 1, yB - 1))
  47. color = self._rand_elem(COLOR_NAMES)
  48. self.grid.set(*pos, Door(color))
  49. # Bottom wall and door
  50. if j + 1 < 3:
  51. self.grid.horz_wall(xL, yB, roomW)
  52. pos = (self._rand_int(xL + 1, xR - 1), yB)
  53. color = self._rand_elem(COLOR_NAMES)
  54. self.grid.set(*pos, Door(color))
  55. # Randomize the player start position and orientation
  56. self.place_agent()
  57. # Place random objects in the world
  58. types = ["key", "ball", "box"]
  59. for i in range(0, 12):
  60. objType = self._rand_elem(types)
  61. objColor = self._rand_elem(COLOR_NAMES)
  62. if objType == "key":
  63. obj = Key(objColor)
  64. elif objType == "ball":
  65. obj = Ball(objColor)
  66. elif objType == "box":
  67. obj = Box(objColor)
  68. else:
  69. raise ValueError(
  70. "{} object type given. Object type can only be of values key, ball and box.".format(
  71. objType
  72. )
  73. )
  74. self.place_obj(obj)
  75. # No explicit mission in this environment
  76. self.mission = ""