|
|
@ -1,7 +1,6 @@ |
|
|
|
import gymnasium as gym |
|
|
|
import numpy as np |
|
|
|
import random |
|
|
|
from moviepy.editor import ImageSequenceClip |
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, common_parser |
|
|
|
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback |
|
|
@ -45,51 +44,6 @@ def parse_sb3_arguments(): |
|
|
|
|
|
|
|
return args |
|
|
|
|
|
|
|
class ImageRecorderCallback(BaseCallback):
    """Log a snapshot of the training env and record an evaluation GIF.

    When training starts, one rendered frame of the training environment is
    recorded under the ``trajectory/image`` logger key. When training ends,
    the model is run through ``evaluation_method`` on ``eval_env`` and every
    frame rendered along the way is written to ``<log_dir>/<n_calls>.gif``.

    :param eval_env: Environment that is evaluated and rendered for the GIF.
    :param render_freq: Number of callback calls between recordings. It is
        only stored: the periodic recording in ``_on_step`` is disabled.
    :param n_eval_episodes: Number of episodes forwarded to ``evaluation_method``.
    :param evaluation_method: Callable invoked as
        ``evaluation_method(model, env, callback=..., n_eval_episodes=...,
        deterministic=...)``; presumably an ``evaluate_policy``-style
        function -- confirm against the caller.
    :param log_dir: Directory the GIF is written to (created if missing).
    :param deterministic: Forwarded to ``evaluation_method``.
    :param verbose: Verbosity level forwarded to ``BaseCallback``.
    :param fps: Frame rate of the written GIF (keyword-only, defaults to the
        previously hard-coded value of 3).
    """

    def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0, *, fps=3):
        super().__init__(verbose)

        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic
        self._evaluation_method = evaluation_method
        self._log_dir = log_dir
        self._fps = fps

    def _on_training_start(self):
        """Record one rendered frame of the training environment."""
        # Imported locally so this method does not rely on a module-level
        # import of `Image` being present in the file.
        from stable_baselines3.common.logger import Image

        image = self.training_env.render(mode="rgb_array")
        if image is None:
            # Nothing was rendered, so there is no image to hand to the logger.
            return
        # "HWC" is the data format handed to `Image` (channel-last layout).
        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))

    def _on_step(self) -> bool:
        """Do nothing per step and let training continue.

        :return: Always ``True``.
        """
        # Periodic recording is currently disabled; the original hook was:
        # if self.n_calls % self._render_freq == 0:
        #     self.record_video()
        return True

    def _on_training_end(self) -> None:
        """Record the evaluation GIF once training has finished."""
        self.record_video()

    def record_video(self) -> bool:
        """Evaluate the model on the eval env and save the rendered frames as a GIF.

        :return: ``True`` if a GIF was written, ``False`` if the evaluation
            produced no frames and nothing was written.
        """
        import os
        import warnings

        frames = []

        def grab_screens(_locals, _globals) -> None:
            """
            Renders the environment in its current state, recording the screen in the captured `frames` list

            :param _locals: A dictionary containing all local variables of the callback's scope
            :param _globals: A dictionary containing all global variables of the callback's scope
            """
            frame = self._eval_env.render()
            # A render call that yields nothing must not end up in the clip.
            if frame is not None:
                frames.append(frame)

        self._evaluation_method(
            self.model,
            self._eval_env,
            callback=grab_screens,
            n_eval_episodes=self._n_eval_episodes,
            deterministic=self._deterministic,
        )

        if not frames:
            # Building a clip from an empty sequence would raise at the very
            # end of training; warn and skip the GIF instead.
            warnings.warn("ImageRecorderCallback: no frames were rendered, skipping GIF export.")
            return False

        if self._log_dir:
            os.makedirs(self._log_dir, exist_ok=True)
        gif_path = os.path.join(self._log_dir, f"{self.n_calls}.gif")

        clip = ImageSequenceClip(frames, fps=self._fps)
        clip.write_gif(gif_path, fps=self._fps)
        return True
|
|
|
|
|
|
|
|
|
|
|
class InfoCallback(BaseCallback): |
|
|
|
xxxxxxxxxx