Browse Source

record gifs when evaluating

This currently does not use tensorboards feature, we have to leave this
as a TBD. The gifs will be left in the experiments log_dir
refactoring
sp 9 months ago
parent
commit
24dec631aa
  1. 11
      examples/shields/rl/13_minigridsb.py
  2. 40
      examples/shields/rl/sb3utils.py

11
examples/shields/rl/13_minigridsb.py

@ -34,6 +34,7 @@ def main():
shield_comparison = args.shield_comparison
log_dir = create_log_dir(args)
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
#new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)])
if shield_needed(args.shielding):
@ -68,19 +69,27 @@ def main():
model.set_logger(new_logger)
steps = args.steps
# Evaluation
eval_freq=max(500, int(args.steps/30))
n_eval_episodes=5
render_freq = eval_freq
if shielded_evaluation(args.shielding):
from sb3_contrib.common.maskable.evaluation import evaluate_policy
evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=eval_freq,
deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
else:
from stable_baselines3.common.evaluation import evaluate_policy
evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=eval_freq,
deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback])
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback])
#vec_env = model.get_env()
#obs = vec_env.reset()

40
examples/shields/rl/sb3utils.py

@ -1,6 +1,7 @@
import gymnasium as gym
import numpy as np
import random
from moviepy.editor import ImageSequenceClip
from utils import MiniGridShieldHandler, common_parser
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
@ -45,14 +46,49 @@ def parse_sb3_arguments():
return args
class ImageRecorderCallback(BaseCallback):
def __init__(self, verbose=0):
def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
super().__init__(verbose)
self._eval_env = eval_env
self._render_freq = render_freq
self._n_eval_episodes = n_eval_episodes
self._deterministic = deterministic
self._evaluation_method = evaluation_method
self._log_dir = log_dir
def _on_training_start(self):
image = self.training_env.render(mode="rgb_array")
self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
def _on_step(self):
def _on_step(self) -> bool:
if self.n_calls % self._render_freq == 0:
self.record_video()
return True
def _on_training_end(self) -> None:
self.record_video()
def record_video(self) -> bool:
screens = []
def grab_screens(_locals, _globals) -> None:
"""
Renders the environment in its current state, recording the screen in the captured `screens` list
:param _locals: A dictionary containing all local variables of the callback's scope
:param _globals: A dictionary containing all global variables of the callback's scope
"""
screen = self._eval_env.render()
screens.append(screen)
self._evaluation_method(
self.model,
self._eval_env,
callback=grab_screens,
n_eval_episodes=self._n_eval_episodes,
deterministic=self._deterministic,
)
clip = ImageSequenceClip(list(screens), fps=3)
clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3)
return True

Loading…
Cancel
Save