You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

133 lines
4.7 KiB

11 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from moviepy.editor import ImageSequenceClip
  5. from utils import MiniGridShieldHandler, common_parser
  6. from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
  7. from stable_baselines3.common.logger import Image
  8. class MiniGridSbShieldingWrapper(gym.core.Wrapper):
  9. def __init__(self,
  10. env,
  11. shield_handler : MiniGridShieldHandler,
  12. create_shield_at_reset = False,
  13. ):
  14. super().__init__(env)
  15. self.shield_handler = shield_handler
  16. self.create_shield_at_reset = create_shield_at_reset
  17. shield = self.shield_handler.create_shield(env=self.env)
  18. self.shield = shield
  19. def create_action_mask(self):
  20. try:
  21. return self.shield[self.env.get_symbolic_state()]
  22. except:
  23. return [0.0] * 3 + [0.0] * 4
  24. def reset(self, *, seed=None, options=None):
  25. obs, infos = self.env.reset(seed=seed, options=options)
  26. if self.create_shield_at_reset:
  27. shield = self.shield_handler.create_shield(env=self.env)
  28. self.shield = shield
  29. return obs, infos
  30. def step(self, action):
  31. obs, rew, done, truncated, info = self.env.step(action)
  32. info["no_shield_action"] = not self.shield.__contains__(self.env.get_symbolic_state())
  33. return obs, rew, done, truncated, info
  34. def parse_sb3_arguments():
  35. parser = common_parser()
  36. args = parser.parse_args()
  37. return args
  38. class ImageRecorderCallback(BaseCallback):
  39. def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
  40. super().__init__(verbose)
  41. self._eval_env = eval_env
  42. self._render_freq = render_freq
  43. self._n_eval_episodes = n_eval_episodes
  44. self._deterministic = deterministic
  45. self._evaluation_method = evaluation_method
  46. self._log_dir = log_dir
  47. def _on_training_start(self):
  48. image = self.training_env.render(mode="rgb_array")
  49. self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
  50. def _on_step(self) -> bool:
  51. #if self.n_calls % self._render_freq == 0:
  52. # self.record_video()
  53. return True
  54. def _on_training_end(self) -> None:
  55. self.record_video()
  56. def record_video(self) -> bool:
  57. screens = []
  58. def grab_screens(_locals, _globals) -> None:
  59. """
  60. Renders the environment in its current state, recording the screen in the captured `screens` list
  61. :param _locals: A dictionary containing all local variables of the callback's scope
  62. :param _globals: A dictionary containing all global variables of the callback's scope
  63. """
  64. screen = self._eval_env.render()
  65. screens.append(screen)
  66. self._evaluation_method(
  67. self.model,
  68. self._eval_env,
  69. callback=grab_screens,
  70. n_eval_episodes=self._n_eval_episodes,
  71. deterministic=self._deterministic,
  72. )
  73. clip = ImageSequenceClip(list(screens), fps=3)
  74. clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3)
  75. return True
  76. class InfoCallback(BaseCallback):
  77. """
  78. Custom callback for plotting additional values in tensorboard.
  79. """
  80. def __init__(self, verbose=0):
  81. super().__init__(verbose)
  82. self.sum_goal = 0
  83. self.sum_lava = 0
  84. self.sum_collisions = 0
  85. self.sum_opened_door = 0
  86. self.sum_picked_up = 0
  87. self.no_shield_action = 0
  88. def _on_step(self) -> bool:
  89. infos = self.locals["infos"][0]
  90. if infos["reached_goal"]:
  91. self.sum_goal += 1
  92. if infos["ran_into_lava"]:
  93. self.sum_lava += 1
  94. self.logger.record("info/sum_reached_goal", self.sum_goal)
  95. self.logger.record("info/sum_ran_into_lava", self.sum_lava)
  96. if "collision" in infos:
  97. if infos["collision"]:
  98. self.sum_collisions += 1
  99. self.logger.record("info/sum_collision", self.sum_collisions)
  100. if "opened_door" in infos:
  101. if infos["opened_door"]:
  102. self.sum_opened_door += 1
  103. self.logger.record("info/sum_opened_door", self.sum_opened_door)
  104. if "picked_up" in infos:
  105. if infos["picked_up"]:
  106. self.sum_picked_up += 1
  107. self.logger.record("info/sum_picked_up", self.sum_picked_up)
  108. if "no_shield_action" in infos:
  109. if infos["no_shield_action"]:
  110. self.no_shield_action += 1
  111. self.logger.record("info/no_shield_action", self.no_shield_action)
  112. return True