You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

118 lines
3.9 KiB

11 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from moviepy.editor import ImageSequenceClip
  5. from utils import MiniGridShieldHandler, common_parser
  6. from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
  7. from stable_baselines3.common.logger import Image
  class MiniGridSbShieldingWrapper(gym.core.Wrapper):
      """Environment wrapper that keeps a shield alongside the wrapped environment.

      The shield is obtained from ``shield_handler.create_shield`` when the
      wrapper is constructed and, optionally, again after every reset.
      ``create_action_mask`` looks up the shield entry for the environment's
      current symbolic state.
      """

      def __init__(self,
                   env,
                   shield_handler: MiniGridShieldHandler,
                   create_shield_at_reset=False,
                   ):
          """Wrap ``env`` and build the initial shield.

          :param env: The environment to wrap.
          :param shield_handler: Object whose ``create_shield(env=...)`` builds the shield.
          :param create_shield_at_reset: If true, the shield is rebuilt after each ``reset``.
          """
          super().__init__(env)
          self.shield_handler = shield_handler
          self.create_shield_at_reset = create_shield_at_reset
          self.shield = self.shield_handler.create_shield(env=self.env)

      def create_action_mask(self):
          """Return the shield entry for the environment's current symbolic state.

          If the lookup fails for any ordinary reason (for example, the state
          is not present in the shield), a list of seven ``1.0`` values is
          returned instead.
          """
          try:
              return self.shield[self.env.get_symbolic_state()]
          except Exception:
              # Best-effort fallback. ``Exception`` (rather than a bare
              # ``except``) lets KeyboardInterrupt/SystemExit propagate.
              # NOTE(review): seven entries presumably match the MiniGrid
              # action count, with 1.0 meaning "allowed" -- confirm.
              return [1.0] * 7

      def reset(self, *, seed=None, options=None):
          """Reset the wrapped environment and, if configured, rebuild the shield.

          :param seed: Forwarded to the wrapped environment's ``reset``.
          :param options: Forwarded to the wrapped environment's ``reset``.
          :return: The ``(obs, infos)`` pair returned by the wrapped environment.
          """
          obs, infos = self.env.reset(seed=seed, options=options)
          if self.create_shield_at_reset:
              # Rebuilt after the reset so the shield is created from the
              # freshly reset environment.
              self.shield = self.shield_handler.create_shield(env=self.env)
          return obs, infos

      def step(self, action):
          """Forward ``action`` to the wrapped environment and return its result unchanged."""
          obs, rew, done, truncated, info = self.env.step(action)
          return obs, rew, done, truncated, info
  def parse_sb3_arguments():
      """Build the parser via ``common_parser`` and return the parsed arguments."""
      return common_parser().parse_args()
  class ImageRecorderCallback(BaseCallback):
      """Callback that logs a start-of-training image and writes an evaluation GIF at the end.

      At training start, one rendered frame of the training environment is
      recorded under ``trajectory/image``. At training end, the evaluation
      method is run on the evaluation environment and the frames collected
      during it are written to ``<log_dir>/<n_calls>.gif``.
      """

      def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
          """Store the evaluation settings.

          :param eval_env: Environment rendered and evaluated when recording.
          :param render_freq: Stored only; the periodic recording that would use it is disabled.
          :param n_eval_episodes: Passed to ``evaluation_method`` as ``n_eval_episodes``.
          :param evaluation_method: Callable invoked as
              ``evaluation_method(model, eval_env, callback=..., n_eval_episodes=..., deterministic=...)``.
          :param log_dir: Directory prefix for the GIF file.
          :param deterministic: Passed to ``evaluation_method`` as ``deterministic``.
          :param verbose: Forwarded to ``BaseCallback``.
          """
          super().__init__(verbose)
          self._eval_env = eval_env
          self._evaluation_method = evaluation_method
          self._n_eval_episodes = n_eval_episodes
          self._deterministic = deterministic
          self._render_freq = render_freq
          self._log_dir = log_dir

      def _on_training_start(self):
          """Render the training environment once and record the frame as an image."""
          frame = self.training_env.render(mode="rgb_array")
          logged_image = Image(frame, "HWC")
          self.logger.record(
              "trajectory/image",
              logged_image,
              exclude=("stdout", "log", "json", "csv"),
          )

      def _on_step(self) -> bool:
          """Do nothing per step and let training continue."""
          # NOTE: periodic recording (every ``_render_freq`` calls) is
          # currently disabled; a video is written only when training ends.
          return True

      def _on_training_end(self) -> None:
          """Write the evaluation GIF once training has finished."""
          self.record_video()

      def record_video(self) -> bool:
          """Run the evaluation method, capture one frame per callback call, and save a GIF.

          :return: Always ``True``.
          """
          frames = []

          def grab_screens(_locals, _globals) -> None:
              """
              Render the evaluation environment and append the result to ``frames``.
              :param _locals: Unused; supplied by the evaluation method
              :param _globals: Unused; supplied by the evaluation method
              """
              frames.append(self._eval_env.render())

          self._evaluation_method(
              self.model,
              self._eval_env,
              callback=grab_screens,
              n_eval_episodes=self._n_eval_episodes,
              deterministic=self._deterministic,
          )

          gif_path = f"{self._log_dir}/{self.n_calls}.gif"
          animation = ImageSequenceClip(list(frames), fps=3)
          animation.write_gif(gif_path, fps=3)
          return True
  class InfoCallback(BaseCallback):
      """
      Callback that accumulates event counts from the step info and logs the running totals.

      Only the first entry of ``self.locals["infos"]`` is inspected.
      """

      def __init__(self, verbose=0):
          """Initialise all running totals to zero."""
          super().__init__(verbose)
          self.sum_goal = 0
          self.sum_lava = 0
          self.sum_collisions = 0

      def _on_step(self) -> bool:
          """Update the totals from the latest info dict and record them.

          The goal and lava totals are recorded on every step; the collision
          total is recorded only when the info dict has a ``"collision"`` key.

          :return: Always ``True``.
          """
          first_info = self.locals["infos"][0]

          self.sum_goal += 1 if first_info["reached_goal"] else 0
          self.sum_lava += 1 if first_info["ran_into_lava"] else 0

          self.logger.record("info/sum_reached_goal", self.sum_goal)
          self.logger.record("info/sum_ran_into_lava", self.sum_lava)

          # "collision" is optional, unlike the two keys above.
          if "collision" in first_info:
              self.sum_collisions += 1 if first_info["collision"] else 0
              self.logger.record("info/sum_collision", self.sum_collisions)

          return True