You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

60 lines
2.4 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.constants import COLOR_NAMES
  3. from minigrid.core.mission import MissionSpace
  4. from minigrid.core.roomgrid import RoomGrid
  5. from minigrid.envs.adversaries_base import AdversaryEnv
  6. from minigrid.core.tasks import GoTo
  7. from minigrid.core.world_object import Door, Box
  class AdversaryDoorPickup(RoomGrid, AdversaryEnv):
      """Two-room box-pickup task with a patrolling adversary.

      Layout (built in ``_gen_grid``): one row of two rooms with ``room_size=6``.
      A yellow door is placed at ``(width // 2, height - 2)`` and a box is added
      to room ``(1, 0)``. A green adversary repeatedly walks a loop of four
      waypoints. The agent starts at ``(1, 1)`` with ``agent_dir = 1``.

      The episode terminates with ``success_reward`` on the step where the agent
      picks up that box. Optional dense shaping is described in ``step``.
      """

      def __init__(
          self,
          success_reward: float = 1,
          collision_penalty: float = -1,
          dense_reward: bool = False,
          max_steps: int | None = None,
          **kwargs,
      ):
          """Create the environment.

          Args:
              success_reward: Reward returned on the step the agent picks up the box.
              collision_penalty: Stored as ``self.collision_penalty``; not read in
                  this class (presumably consumed by ``AdversaryEnv`` -- confirm).
              dense_reward: Enable the door-toggle and position shaping terms
                  applied in ``step``.
              max_steps: Episode step limit. ``None`` selects the default of 200.
              **kwargs: Forwarded to the next ``__init__`` in the MRO (``RoomGrid``).
          """
          # Honour a caller-supplied limit; only fall back to 200 when none is given.
          if max_steps is None:
              max_steps = 200

          # NOTE(review): no ``mission_space`` is passed here (``MissionSpace`` is
          # imported but unused); presumably supplied via kwargs or by a base
          # class -- confirm.
          super().__init__(
              num_rows=1,
              num_cols=2,
              room_size=6,
              max_steps=max_steps,
              **kwargs,
          )
          self.success_reward = success_reward
          self.collision_penalty = collision_penalty
          self.dense_reward = dense_reward

      def _gen_grid(self, width, height):
          """Build the rooms, then place the door, box, adversary and agent start."""
          super()._gen_grid(width, height)
          self.width = width

          # Set after the base layout so this fixed start pose takes precedence.
          self.agent_pos = (1, 1)
          self.agent_dir = 1

          mid_x = width // 2
          self.put_obj(Door("yellow"), mid_x, height - 2)

          # add_object returns a pair whose first element is the placed box;
          # keep it so step() can recognise when the agent picks it up.
          box, _ = self.add_object(1, 0, kind="box")
          self.object = box

          # Green adversary cycling through four waypoints forever
          # (presumably the corners of the left room's interior -- confirm).
          self.add_adversary(
              mid_x - 1,
              1,
              "green",
              direction=1,
              tasks=[
                  GoTo((mid_x - 1, 4)),
                  GoTo((1, 4)),
                  GoTo((1, 1)),
                  GoTo((mid_x - 1, 1)),
              ],
              repeating=True,
          )

      def step(self, action):
          """Advance one step and apply the task-specific reward terms.

          Returns the base ``(obs, reward, terminated, truncated, info)`` tuple with:

          * ``reward`` replaced by ``success_reward`` and ``terminated`` set to
            ``True`` when a pickup action leaves the agent carrying the target box;
          * when ``dense_reward`` is enabled, +0.1 if a toggle leaves the door in
            front open, -0.11 if it leaves it closed, and a penalty of
            ``0.001 * (width - x)`` whenever the agent's x-coordinate is below 7.
          """
          obs, reward, terminated, truncated, info = super().step(action)

          if (
              action == self.actions.pickup
              and self.carrying
              and self.carrying == self.object
          ):
              reward = self.success_reward
              terminated = True

          if self.dense_reward and action == self.actions.toggle:
              fwd_cell = self.grid.get(*self.front_pos)
              if fwd_cell and fwd_cell.type == "door":
                  # Read after super().step(), so this reflects the door state
                  # resulting from this toggle. The close penalty is slightly
                  # larger, making an open/close cycle net negative (-0.01).
                  if fwd_cell.is_open:
                      reward += 0.1
                  else:
                      reward -= 0.11

          # NOTE(review): the threshold 7 is a literal, not derived from
          # self.width; confirm it matches the intended layout. This term also
          # applies on the success step, slightly reducing success_reward.
          if self.dense_reward and self.agent_pos[0] < 7:
              reward -= 0.001 * (self.width - self.agent_pos[0])

          return obs, reward, terminated, truncated, info