Browse Source

removed VideoCallback

and moviepy dependency
main
sp 2 months ago
parent
commit
df402f9c4c
  1. 1
      dockerfile
  2. 46
      notebooks/sb3utils.py

1
dockerfile

@ -83,7 +83,6 @@ RUN pip install ipywidgets
RUN pip install matplotlib
RUN pip install sb3-contrib
RUN pip install opencv-python
RUN pip install moviepy
RUN pip install gymnasium==0.29.0
RUN pip install numpy==1.24.4

46
notebooks/sb3utils.py

@ -1,7 +1,6 @@
import gymnasium as gym
import numpy as np
import random
from moviepy.editor import ImageSequenceClip
from utils import MiniGridShieldHandler, common_parser
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
@ -45,51 +44,6 @@ def parse_sb3_arguments():
return args
class ImageRecorderCallback(BaseCallback):
def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
super().__init__(verbose)
self._eval_env = eval_env
self._render_freq = render_freq
self._n_eval_episodes = n_eval_episodes
self._deterministic = deterministic
self._evaluation_method = evaluation_method
self._log_dir = log_dir
def _on_training_start(self):
image = self.training_env.render(mode="rgb_array")
self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
def _on_step(self) -> bool:
#if self.n_calls % self._render_freq == 0:
# self.record_video()
return True
def _on_training_end(self) -> None:
self.record_video()
def record_video(self) -> bool:
screens = []
def grab_screens(_locals, _globals) -> None:
"""
Renders the environment in its current state, recording the screen in the captured `screens` list
:param _locals: A dictionary containing all local variables of the callback's scope
:param _globals: A dictionary containing all global variables of the callback's scope
"""
screen = self._eval_env.render()
screens.append(screen)
self._evaluation_method(
self.model,
self._eval_env,
callback=grab_screens,
n_eval_episodes=self._n_eval_episodes,
deterministic=self._deterministic,
)
clip = ImageSequenceClip(list(screens), fps=3)
clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3)
return True
class InfoCallback(BaseCallback):

|||||||
100:0
Loading…
Cancel
Save