You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

89 lines
2.8 KiB

3 months ago
  1. from __future__ import annotations
  2. from minigrid.core.grid import Grid
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
  5. from minigrid.envs.adversaries_base import AdversaryEnv
  6. from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject, DoRandom, FollowAgent
  7. import numpy as np
  class AdversaryDebug(AdversaryEnv):
      """
      ## Description

      Small walled grid used for debugging adversary behaviour. The agent is
      placed at (1, 1) with direction 1, and one blue adversary is added at
      (3, 3) that alternates between following the "red" agent and acting
      randomly. Depending on the constructor flags, the interior additionally
      contains a short wall with two slippery tiles, a vertical lava column
      with a single gap, or two lava tiles surrounded by slippery tiles.

      ## Registered Configurations

      - `MiniGrid-Adv-8x8-v0`
      - `MiniGrid-AdvLava-8x8-v0`
      - `MiniGrid-AdvSlipperyLava-8x8-v0`
      - `MiniGrid-AdvDebug-8x8-v0`
      """

      def __init__(
          self,
          width=7,
          height=6,
          generate_wall=False,
          generate_lava=False,
          generate_slippery=False,
          max_steps: int | None = None,
          **kwargs,
      ):
          """Store the layout flags and initialise the base environment.

          Args:
              width: Grid width in cells, including the outer wall.
              height: Grid height in cells, including the outer wall.
              generate_wall: Build the wall-plus-slippery-tiles layout.
              generate_lava: Build the lava-column layout. Only consulted
                  when ``generate_wall`` is false.
              generate_slippery: Build the lava/slippery-tile layout. Only
                  consulted when both other flags are false.
              max_steps: Step limit forwarded to the base class. Defaults to
                  ``10 * (width * height) ** 2`` when ``None``.
              **kwargs: Forwarded unchanged to ``AdversaryEnv.__init__``.
          """
          if max_steps is None:
              max_steps = 10 * (width * height) ** 2

          # The flags are set before super().__init__() so they already exist
          # if the base constructor triggers grid generation.
          # NOTE(review): whether the base class does so is not visible here.
          self.generate_wall = generate_wall
          self.generate_lava = generate_lava
          self.generate_slippery = generate_slippery

          super().__init__(
              width=width, height=height, max_steps=max_steps, **kwargs
          )

      def __generate_slippery(self, width, height):
          """Place two lava tiles, each with three slippery tiles around it."""
          # Lower-left cluster around the lava tile at (2, height - 2).
          self.put_obj(Lava(), 2, height - 2)
          # Right-hand cluster around the lava tile at (width - 2, height - 4).
          self.put_obj(Lava(), width - 2, height - 4)

          self.put_obj(SlipperyEast(), 3, height - 2)
          self.put_obj(SlipperyWest(), 1, height - 2)
          self.put_obj(SlipperyNorth(), 2, height - 3)

          self.put_obj(SlipperyNorth(), width - 2, height - 5)
          self.put_obj(SlipperyWest(), width - 3, height - 4)
          self.put_obj(SlipperySouth(), width - 2, height - 3)

      def __generate_lava(self, width, height):
          """Draw a vertical lava column through the grid centre with one gap.

          The gap position is stored on ``self.gap_pos`` as a NumPy array
          ``(width // 2, height // 2)``.
          """
          self.gap_pos = np.array((width // 2, height // 2))

          # Lava column at x = gap_pos[0], starting at y = 1, length height - 2.
          self.grid.vert_wall(self.gap_pos[0], 1, height - 2, Lava)

          # Put a hole in the wall
          self.grid.set(*self.gap_pos, None)

      def _gen_grid(self, width, height):
          """Build the grid, place the agent and add the blue adversary.

          Exactly one optional layout is generated, chosen in priority order
          ``generate_wall`` > ``generate_lava`` > ``generate_slippery``; with
          all flags false the interior stays empty.
          """
          self.grid = Grid(width, height)
          self.grid.wall_rect(0, 0, width, height)

          # Fixed agent start: top-left interior cell, direction 1.
          self.agent_pos = np.array((1, 1))
          self.agent_dir = 1

          if self.generate_wall:
              wall_length = 3
              self.grid.horz_wall(width - wall_length - 2, 2, wall_length)
              self.put_obj(SlipperyEast(), width - 3, 1)
              self.put_obj(SlipperyNorth(), 3, height - 2)
          elif self.generate_lava:
              self.__generate_lava(width, height)
          elif self.generate_slippery:
              self.__generate_slippery(width, height)

          # NOTE(review): the original indentation was lost; this call is
          # assumed to run for every layout rather than only in the last
          # branch above - confirm against the upstream file.
          self.add_adversary(
              3,
              3,
              "blue",
              direction=1,
              tasks=[FollowAgent("red", duration=5), DoRandom(duration=1)],
              repeating=True,
          )

      def step(self, action):
          """Delegate to the base class and return its result unchanged.

          Returns:
              The ``(obs, reward, terminated, truncated, info)`` tuple
              produced by ``super().step(action)``.
          """
          obs, reward, terminated, truncated, info = super().step(action)
          return obs, reward, terminated, truncated, info