You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.4 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.grid import Grid
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
  5. from minigrid.envs.adversaries_base import AdversaryEnv
  6. from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject
  7. import numpy as np
  8. class AdversarySimple(AdversaryEnv):
  9. """
  10. ## Description
  11. ## Registered Configurations
  12. - `MiniGrid-Adv-8x8-v0`
  13. - `MiniGrid-AdvLava-8x8-v0`
  14. - `MiniGrid-AdvSlipperyLava-8x8-v0`
  15. - `MiniGrid-AdvSimple-8x8-v0`
  16. """
  17. def __init__(self, width=7, height=6, generate_wall=True, generate_lava=False, generate_slippery=False ,max_steps: int | None = None, **kwargs):
  18. if max_steps is None:
  19. max_steps = 200
  20. self.generate_wall = generate_wall
  21. super().__init__(
  22. width=width, height=height, max_steps=max_steps, **kwargs
  23. )
  24. def _gen_grid(self, width, height):
  25. self.grid = Grid(width, height)
  26. self.grid.wall_rect(0, 0, width, height)
  27. goal_pos = np.array((width - 2, height - 2))
  28. ball_pos = np.array((width - 3, height - 2))
  29. self.put_obj(Goal(), *goal_pos)
  30. self.put_obj(Ball("yellow"), *ball_pos)
  31. ball = self.grid.get(*ball_pos)
  32. self.adversaries = {}
  33. self.agent_pos = np.array((width - 2, 1))
  34. self.agent_dir = 2
  35. self.grid.horz_wall(2, height - 3)
  36. blue_adv = self.add_adversary(1, 1, "yellow", direction=1, tasks=[GoTo((width - 4, height - 2)),
  37. PickUpObject(ball_pos, ball),
  38. GoTo((1,1)),
  39. PlaceObject((2, 1), ball),
  40. DoNothing(duration=2),
  41. PickUpObject((2, 1), ball),
  42. GoTo((width - 4, height - 2)),
  43. PlaceObject(ball_pos, ball)], repeating=True)
  44. def step(self, action):
  45. obs, reward, terminated, truncated, info = super().step(action)
  46. return obs, reward, terminated, truncated, info