diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 578bf2e..175dc8e 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -34,6 +34,7 @@ def main(): shield_comparison = args.shield_comparison log_dir = create_log_dir(args) new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) + #new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) if shield_needed(args.shielding): @@ -68,19 +69,27 @@ def main(): model.set_logger(new_logger) steps = args.steps + + # Evaluation eval_freq=max(500, int(args.steps/30)) n_eval_episodes=5 + render_freq = eval_freq if shielded_evaluation(args.shielding): + from sb3_contrib.common.maskable.evaluation import evaluate_policy evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, log_path=log_dir, eval_freq=eval_freq, deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) else: + from stable_baselines3.common.evaluation import evaluate_policy evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, log_path=log_dir, eval_freq=eval_freq, deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) + - model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback]) + model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) #vec_env = model.get_env() #obs = vec_env.reset() diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index bcae0a3..4aff45c 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -1,6 +1,7 @@ import gymnasium as gym import numpy as np import random +from moviepy.editor import ImageSequenceClip from utils import MiniGridShieldHandler, common_parser from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback @@ -45,14 +46,49 @@ def parse_sb3_arguments(): return args class ImageRecorderCallback(BaseCallback): - def __init__(self, verbose=0): + def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0): super().__init__(verbose) + self._eval_env = eval_env + self._render_freq = render_freq + self._n_eval_episodes = n_eval_episodes + self._deterministic = deterministic + self._evaluation_method = evaluation_method + self._log_dir = log_dir + def _on_training_start(self): image = self.training_env.render(mode="rgb_array") self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) - def _on_step(self): + def _on_step(self) -> bool: + if self.n_calls % self._render_freq == 0: + self.record_video() + return True + + def _on_training_end(self) -> None: + self.record_video() + + def record_video(self) -> bool: + screens = [] + def grab_screens(_locals, _globals) -> None: + """ + Renders the environment in its current state, recording the screen in the captured `screens` list + + :param _locals: A dictionary containing all local variables of the callback's scope + :param _globals: A dictionary containing all global variables of the callback's scope + """ + screen = self._eval_env.render() + screens.append(screen) + self._evaluation_method( + self.model, + self._eval_env, + callback=grab_screens, + n_eval_episodes=self._n_eval_episodes, + deterministic=self._deterministic, + ) + + clip = ImageSequenceClip(list(screens), fps=3) + clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3) return True