You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

133 lines
4.7 KiB

3 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from moviepy.editor import ImageSequenceClip
  5. from utils import MiniGridShieldHandler, common_parser
  6. from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
  7. from stable_baselines3.common.logger import Image
  8. class MiniGridSbShieldingWrapper(gym.core.Wrapper):
  9. def __init__(self,
  10. env,
  11. shield_handler : MiniGridShieldHandler,
  12. create_shield_at_reset = False,
  13. ):
  14. super().__init__(env)
  15. self.shield_handler = shield_handler
  16. self.create_shield_at_reset = create_shield_at_reset
  17. shield = self.shield_handler.create_shield(env=self.env)
  18. self.shield = shield
  19. def create_action_mask(self):
  20. try:
  21. return self.shield[self.env.get_symbolic_state()]
  22. except:
  23. return [0.0] * 3 + [0.0] * 4
  24. def reset(self, *, seed=None, options=None):
  25. obs, infos = self.env.reset(seed=seed, options=options)
  26. if self.create_shield_at_reset:
  27. shield = self.shield_handler.create_shield(env=self.env)
  28. self.shield = shield
  29. return obs, infos
  30. def step(self, action):
  31. obs, rew, done, truncated, info = self.env.step(action)
  32. info["no_shield_action"] = not self.shield.__contains__(self.env.get_symbolic_state())
  33. return obs, rew, done, truncated, info
  34. def parse_sb3_arguments():
  35. parser = common_parser()
  36. args = parser.parse_args()
  37. return args
  38. class ImageRecorderCallback(BaseCallback):
  39. def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
  40. super().__init__(verbose)
  41. self._eval_env = eval_env
  42. self._render_freq = render_freq
  43. self._n_eval_episodes = n_eval_episodes
  44. self._deterministic = deterministic
  45. self._evaluation_method = evaluation_method
  46. self._log_dir = log_dir
  47. def _on_training_start(self):
  48. image = self.training_env.render(mode="rgb_array")
  49. self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
  50. def _on_step(self) -> bool:
  51. #if self.n_calls % self._render_freq == 0:
  52. # self.record_video()
  53. return True
  54. def _on_training_end(self) -> None:
  55. self.record_video()
  56. def record_video(self) -> bool:
  57. screens = []
  58. def grab_screens(_locals, _globals) -> None:
  59. """
  60. Renders the environment in its current state, recording the screen in the captured `screens` list
  61. :param _locals: A dictionary containing all local variables of the callback's scope
  62. :param _globals: A dictionary containing all global variables of the callback's scope
  63. """
  64. screen = self._eval_env.render()
  65. screens.append(screen)
  66. self._evaluation_method(
  67. self.model,
  68. self._eval_env,
  69. callback=grab_screens,
  70. n_eval_episodes=self._n_eval_episodes,
  71. deterministic=self._deterministic,
  72. )
  73. clip = ImageSequenceClip(list(screens), fps=3)
  74. clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3)
  75. return True
  76. class InfoCallback(BaseCallback):
  77. """
  78. Custom callback for plotting additional values in tensorboard.
  79. """
  80. def __init__(self, verbose=0):
  81. super().__init__(verbose)
  82. self.sum_goal = 0
  83. self.sum_lava = 0
  84. self.sum_collisions = 0
  85. self.sum_opened_door = 0
  86. self.sum_picked_up = 0
  87. self.no_shield_action = 0
  88. def _on_step(self) -> bool:
  89. infos = self.locals["infos"][0]
  90. if infos["reached_goal"]:
  91. self.sum_goal += 1
  92. if infos["ran_into_lava"]:
  93. self.sum_lava += 1
  94. self.logger.record("info/sum_reached_goal", self.sum_goal)
  95. self.logger.record("info/sum_ran_into_lava", self.sum_lava)
  96. if "collision" in infos:
  97. if infos["collision"]:
  98. self.sum_collisions += 1
  99. self.logger.record("info/sum_collision", self.sum_collisions)
  100. if "opened_door" in infos:
  101. if infos["opened_door"]:
  102. self.sum_opened_door += 1
  103. self.logger.record("info/sum_opened_door", self.sum_opened_door)
  104. if "picked_up" in infos:
  105. if infos["picked_up"]:
  106. self.sum_picked_up += 1
  107. self.logger.record("info/sum_picked_up", self.sum_picked_up)
  108. if "no_shield_action" in infos:
  109. if infos["no_shield_action"]:
  110. self.no_shield_action += 1
  111. self.logger.record("info/no_shield_action", self.no_shield_action)
  112. return True