You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
4.8 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.grid import Grid
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
  5. from minigrid.minigrid_env import MiniGridEnv, is_slippery
  6. from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject
  7. import numpy as np
  8. class AdversaryEnv(MiniGridEnv):
  9. """
  10. ## Description
  11. """
  def __init__(
      self,
      width=7,
      height=6,
      generate_wall=True,
      generate_lava=False,
      generate_slippery=False,
      max_steps: int | None = None,
      **kwargs,
  ):
      """Set up the adversary environment and delegate to the parent constructor.

      Args:
          width: Grid width, forwarded to the parent constructor.
          height: Grid height, forwarded to the parent constructor.
          generate_wall: Accepted, but not read by this constructor.
          generate_lava: Accepted, but not read by this constructor.
          generate_slippery: Accepted, but not read by this constructor.
          max_steps: Step limit forwarded to the parent constructor. When
              ``None`` it defaults to ``10 * (width * height) ** 2``.
          **kwargs: Forwarded unchanged to the parent constructor.
      """
      # Fall back to a step budget derived from the grid area when none is given.
      step_limit = 10 * (width * height) ** 2 if max_steps is None else max_steps

      # Reward assigned by `step` when an adversary collides with the agent.
      self.collision_penalty = -1

      super().__init__(
          mission_space=MissionSpace(mission_func=self._gen_mission),
          width=width,
          height=height,
          max_steps=step_limit,
          **kwargs,
      )
  @staticmethod
  def _gen_mission():
      """Return the fixed mission text used for this environment."""
      mission_text = "Finish your task while avoiding the adversaries"
      return mission_text
  def _gen_grid(self, width, height):
      """Create a fresh grid of the given size and draw its outer wall rectangle.

      Args:
          width: Number of grid columns.
          height: Number of grid rows.

      Only the grid and the wall rectangle spanning its full extent are
      created here; nothing else is placed by this method.
      """
      self.grid = arena = Grid(width, height)
      arena.wall_rect(0, 0, width, height)
  def step(self, action):
      """Advance the environment by one agent action, then move every adversary.

      Before the agent acts, every entry of ``self.background_tiles`` whose grid
      cell is currently empty is written back into the grid, its background
      marker is cleared, and the entry is removed from the mapping.  The action
      is then delegated to the parent ``step``.  Unless that call already
      terminated the episode, each adversary in ``self.adversaries`` is moved
      once via ``move_adversary`` and its ``(color, position, direction)`` is
      appended to ``self.trajectory``.

      Args:
          action: The agent's action, forwarded unchanged to the parent ``step``.

      Returns:
          The ``(obs, reward, terminated, truncated, info)`` tuple of the parent
          ``step``.  When an adversary collides with the agent, ``terminated``
          and ``info["collision"]`` are set to ``True`` and ``reward`` is
          replaced by ``self.collision_penalty`` (``-1`` if that attribute does
          not exist).  ``obs`` is returned exactly as the parent produced it,
          i.e. it is not regenerated after the adversaries have moved.
      """
      # Iterate over a snapshot so entries can be removed while looping.
      for position, tile in list(self.background_tiles.items()):
          if self.grid.get(*position) is None:
              self.grid.set(*position, tile)
              self.grid.set_background(*position, None)
              del self.background_tiles[position]

      obs, reward, terminated, truncated, info = super().step(action)

      if not terminated:
          agent_pos = self.agent_pos
          for adversary in self.adversaries.values():
              collided = self.move_adversary(adversary, agent_pos)
              self.trajectory.append(
                  (adversary.color, adversary.adversary_pos, adversary.adversary_dir)
              )
              if collided:
                  terminated = True
                  info["collision"] = True
                  # The original `except e:` referenced an undefined name, so
                  # its -1 fallback could never run; getattr gives the same
                  # intended default without a broken handler.
                  reward = getattr(self, "collision_penalty", -1)

      return obs, reward, terminated, truncated, info
  def move_adversary(self, adversary, agent_pos):
      """Apply one action chosen by ``adversary`` and report a collision.

      The action is obtained from ``adversary.get_action(self)`` and compared
      against the members of ``self.actions`` (left, right, forward, pickup,
      drop, toggle, done).

      Args:
          adversary: Object exposing ``adversary_pos``, ``adversary_dir``,
              ``carrying``, ``dir_vec()`` and ``get_action(env)``.  It is
              mutated in place.
          agent_pos: Position of the agent; only indices 0 and 1 are read.

      Returns:
          ``True`` if the adversary took a forward action whose target
          position has the same coordinates as ``agent_pos``, else ``False``.

      Raises:
          ValueError: If the chosen action matches none of the handled actions.
      """
      # Where the adversary stands and the cell straight ahead of it.
      start_pos = adversary.adversary_pos
      start_cell = self.grid.get(*start_pos)
      target_pos = start_pos + adversary.dir_vec()
      target_cell = self.grid.get(*target_pos)

      hit_agent = False
      slipped = False

      chosen = adversary.get_action(self)
      actions = self.actions

      if chosen == actions.forward and is_slippery(start_cell):
          # Moving forward off a slippery cell: the forward target is re-drawn
          # at random from the candidates returned by get_neighbours_prob,
          # weighted by the probabilities returned alongside them.
          slip_probs = start_cell.get_probabilities(adversary.adversary_dir)
          candidates, weights = self.get_neighbours_prob(start_pos, slip_probs)
          # NOTE(review): this draws from NumPy's global RNG; confirm whether
          # the environment's own seeded generator should be used instead.
          drawn = np.random.choice(len(candidates), 1, p=weights)
          target_pos = candidates[drawn[0]]
          target_cell = self.grid.get(*target_pos)
          slipped = True

      if chosen == actions.left:
          # Rotate left, wrapping a negative heading back up by 4.
          heading = adversary.adversary_dir - 1
          adversary.adversary_dir = heading + 4 if heading < 0 else heading
      elif chosen == actions.right:
          # Rotate right, wrapping with modulo 4.
          adversary.adversary_dir = (adversary.adversary_dir + 1) % 4
      elif chosen == actions.forward:
          # A collision is reported whenever the forward target is the agent's cell.
          if target_pos[0] == agent_pos[0] and target_pos[1] == agent_pos[1]:
              hit_agent = True
          if target_cell is None or target_cell.can_overlap():
              adversary.adversary_pos = tuple(target_pos)
      elif chosen == actions.pickup:
          # Pick up the object ahead only when the adversary's hands are free.
          if target_cell and target_cell.can_pickup() and adversary.carrying is None:
              adversary.carrying = target_cell
              adversary.carrying.cur_pos = np.array([-1, -1])
              self.grid.set(target_pos[0], target_pos[1], None)
      elif chosen == actions.drop:
          # Drop the carried object into the cell ahead if that cell is free.
          if not target_cell and adversary.carrying:
              self.grid.set(target_pos[0], target_pos[1], adversary.carrying)
              adversary.carrying.cur_pos = target_pos
              adversary.carrying = None
      elif chosen == actions.toggle:
          # Toggle/activate whatever object is ahead.
          if target_cell:
              target_cell.toggle(self, target_pos)
      elif chosen == actions.done:
          # Done action: nothing to do.
          pass
      else:
          raise ValueError(f"Unknown action: {chosen}")

      # NOTE(review): `slipped` is only ever set for a forward action, where
      # the forward branch above has already applied this same check and
      # assignment; kept to preserve the original behaviour exactly.
      if slipped and (target_cell is None or target_cell.can_overlap()):
          adversary.adversary_pos = tuple(target_pos)

      return hit_agent