You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
3.8 KiB

2 months ago
  1. from __future__ import annotations
  2. from minigrid.core.constants import COLOR_NAMES
  3. from minigrid.core.grid import Grid
  4. from minigrid.core.mission import MissionSpace
  5. from minigrid.core.world_object import (
  6. Ball,
  7. Box,
  8. Key,
  9. Slippery,
  10. SlipperyEast,
  11. SlipperySouth,
  12. SlipperyNorth,
  13. SlipperyWest,
  14. Lava,
  15. Goal,
  16. Point
  17. )
  18. from minigrid.minigrid_env import MiniGridEnv
  19. import numpy as np
  20. import random
  21. class LavaFaultyEnv(MiniGridEnv):
  22. """
  23. ### Registered Configurations
  24. S: size of map SxS.
  25. V: Version
  26. - `MiniGrid-LavaFaultyS12-v0`
  27. """
  28. def __init__(self,
  29. size=12,
  30. width=None,
  31. height=None,
  32. gap=5,
  33. fault_probability=0.1,
  34. per_step_penalty=0.0,
  35. faulty_behavior=True,
  36. obstacle_type=Lava,
  37. randomize_start=True,
  38. **kwargs):
  39. self.obstacle_type = obstacle_type
  40. self.size = size
  41. self.gap = gap
  42. self.fault_probability = fault_probability
  43. self.faulty_behavior = faulty_behavior
  44. self.previous_action = None
  45. self.per_step_penalty = per_step_penalty
  46. self.randomize_start = randomize_start
  47. if width is not None and height is not None:
  48. self.width = width
  49. self.height = height
  50. else:
  51. self.width = size
  52. self.height = size
  53. if obstacle_type == Lava:
  54. mission_space = MissionSpace(
  55. mission_func=lambda: "avoid the lava and get to the green goal square"
  56. )
  57. else:
  58. mission_space = MissionSpace(
  59. mission_func=lambda: "find the opening and get to the green goal square"
  60. )
  61. super().__init__(
  62. mission_space=mission_space,
  63. width=self.width,
  64. height=self.height,
  65. max_steps=200,
  66. # Set this to True for maximum speed
  67. see_through_walls=False,
  68. **kwargs
  69. )
  70. def fault(self):
  71. return True if random.random() < self.fault_probability else False
  72. def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
  73. if self.step_count > 0 and self.fault():
  74. action = self.previous_action
  75. self.previous_action = action
  76. obs, reward, terminated, trucated, info = super().step(action)
  77. return obs, reward - self.per_step_penalty, terminated, trucated, info
  78. def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]:
  79. self.previous_action = None
  80. return super().reset(**kwargs)
  81. def _gen_grid(self, width, height):
  82. assert width >= 5 and height >= 5
  83. # Create an empty grid
  84. self.grid = Grid(width, height)
  85. for row in range(1, height - 1):
  86. if row < (height - self.gap):
  87. self.grid.horz_wall(1, row, width - self.gap - row, Lava)
  88. for i, col in enumerate(reversed(range(1, width - 1))):
  89. self.grid.vert_wall(col, self.gap + i, None, Lava)
  90. self.grid.wall_rect(0, 0, width, height)
  91. if self.randomize_start:
  92. self.place_agent()
  93. else:
  94. self.agent_pos = np.array((1, height - 2))
  95. self.agent_dir = 3
  96. self.mission = (
  97. "avoid the lava and get to the green goal square"
  98. if self.obstacle_type == Lava
  99. else "find the opening and get to the green goal square"
  100. )
  101. self.put_obj(Goal(), width - 2, 1)
  102. def disable_random_start(self):
  103. self.randomize_start = False
  104. def printGrid(self, init=False):
  105. grid = super().printGrid(init)
  106. properties_str = ""
  107. if self.faulty_behavior:
  108. properties_str += F"FaultProbability:{self.fault_probability}\n"
  109. return grid + properties_str