You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

43 lines
1.8 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.grid import Grid
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
  5. from minigrid.envs.adversaries_base import AdversaryEnv
  6. from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject
  7. import numpy as np
  8. class OscillatingAdversaries(AdversaryEnv):
  9. def __init__(self, width=8, height=8, max_steps: int | None = None, **kwargs):
  10. if max_steps is None:
  11. max_steps = 200
  12. super().__init__(
  13. width=width, height=height, max_steps=max_steps, **kwargs
  14. )
  15. def _gen_grid(self, width, height):
  16. assert width >= 8 and height >= 8
  17. self.grid = Grid(width, height)
  18. self.grid.wall_rect(0, 0, width, height)
  19. goal_pos = np.array((int(width/2) - 1, height - 2))
  20. self.put_obj(Goal(), *goal_pos)
  21. self.adversaries = {}
  22. self.agent_pos = np.array((int(width/2), 1))
  23. self.agent_dir = 2
  24. yellow_adv = self.add_adversary(1, 3, "yellow", direction=0, tasks=[GoTo((3, 3)),
  25. GoTo((3, 1)),
  26. GoTo((1, 3))], repeating=True)
  27. green_adv = self.add_adversary(width - 2, 5, "green", direction=3, tasks=[GoTo((width - 2, 3)),
  28. GoTo((width - 4, 3)),
  29. GoTo((width - 4, 5))], repeating=True)
  30. def step(self, action):
  31. obs, reward, terminated, truncated, info = super().step(action)
  32. return obs, reward, terminated, truncated, info