You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

118 lines
3.8 KiB

3 months ago
  1. from __future__ import annotations
  2. from minigrid.core.constants import COLOR_NAMES
  3. from minigrid.core.grid import Grid
  4. from minigrid.core.mission import MissionSpace
  5. from minigrid.core.world_object import (
  6. Ball,
  7. Box,
  8. Key,
  9. Slippery,
  10. SlipperyEast,
  11. SlipperySouth,
  12. SlipperyNorth,
  13. SlipperyWest,
  14. Lava,
  15. Goal,
  16. Point
  17. )
  18. from minigrid.minigrid_env import MiniGridEnv
  19. import numpy as np
  20. import random
  21. class Playground(MiniGridEnv):
  22. """
  23. An empty Playground environment for Graz Security Week
  24. """
  25. def __init__(self,
  26. size=12,
  27. width=None,
  28. height=None,
  29. fault_probability=0.0,
  30. per_step_penalty=0.0,
  31. probability_intended=1.0,
  32. probability_turn_intended=1.0,
  33. faulty_behavior=True,
  34. randomize_start=True,
  35. **kwargs):
  36. self.size = size
  37. self.fault_probability = fault_probability
  38. self.faulty_behavior = faulty_behavior
  39. self.previous_action = None
  40. self.per_step_penalty = per_step_penalty
  41. self.randomize_start = randomize_start
  42. self.probability_intended = probability_intended
  43. self.probability_turn_intended = probability_turn_intended
  44. if width is not None and height is not None:
  45. self.width = width
  46. self.height = height
  47. else:
  48. self.width = size
  49. self.height = size
  50. mission_space = MissionSpace(mission_func=lambda: "get to the green goal square")
  51. super().__init__(
  52. mission_space=mission_space,
  53. width=self.width,
  54. height=self.height,
  55. max_steps=200,
  56. see_through_walls=False,
  57. **kwargs
  58. )
  59. def fault(self):
  60. return True if random.random() < self.fault_probability else False
  61. def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
  62. if self.step_count > 0 and self.fault():
  63. action = self.previous_action
  64. self.previous_action = action
  65. obs, reward, terminated, trucated, info = super().step(action)
  66. return obs, reward - self.per_step_penalty, terminated, trucated, info
  67. def reset(self, **kwargs) -> tuple[ObsType, dict[str, Any]]:
  68. self.previous_action = None
  69. return super().reset(**kwargs)
  70. def _gen_grid(self, width, height):
  71. assert width >= 5 and height >= 5
  72. # Create an empty grid
  73. self.grid = Grid(width, height)
  74. slippery_north = SlipperyNorth(probability_intended=self.probability_intended, probability_turn_intended=self.probability_turn_intended)
  75. slippery_east = SlipperyEast(probability_intended=self.probability_intended, probability_turn_intended=self.probability_turn_intended)
  76. slippery_south = SlipperySouth(probability_intended=self.probability_intended, probability_turn_intended=self.probability_turn_intended)
  77. slippery_west = SlipperyWest(probability_intended=self.probability_intended, probability_turn_intended=self.probability_turn_intended)
  78. # A rectangular wall around the environment
  79. self.grid.wall_rect(0, 0, width, height)
  80. # Change the goal position:
  81. self.put_obj(Goal(), width - 2, 1)
  82. # TODO: Add walls, pools of lava, etc.
  83. if self.randomize_start:
  84. self.place_agent()
  85. else:
  86. self.agent_pos = np.array((1, height - 2))
  87. self.agent_dir = 3
  88. def disable_random_start(self):
  89. self.randomize_start = False
  90. def printGrid(self, init=False):
  91. grid = super().printGrid(init)
  92. properties_str = ""
  93. if self.faulty_behavior:
  94. properties_str += F"FaultProbability:{self.fault_probability}\n"
  95. return grid + properties_str