diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index f18f2af..f260c62 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -3,6 +3,8 @@ import numpy as np import random from utils import MiniGridShieldHandler, common_parser +from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback +from stable_baselines3.common.logger import Image class MiniGridSbShieldingWrapper(gym.core.Wrapper): def __init__(self, @@ -43,3 +45,40 @@ def parse_sb3_arguments(): args = parser.parse_args() return args + +class ImageRecorderCallback(BaseCallback): + def __init__(self, verbose=0): + super().__init__(verbose) + + def _on_training_start(self): + image = self.training_env.render(mode="rgb_array") + self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) + + def _on_step(self): + return True + + +class InfoCallback(BaseCallback): + """ + Custom callback for plotting additional values in tensorboard. + """ + + def __init__(self, verbose=0): + super().__init__(verbose) + self.sum_goal = 0 + self.sum_lava = 0 + self.sum_collisions = 0 + + def _on_step(self) -> bool: + infos = self.locals["infos"][0] + if infos["reached_goal"]: + self.sum_goal += 1 + if infos["ran_into_lava"]: + self.sum_lava += 1 + self.logger.record("info/sum_reached_goal", self.sum_goal) + self.logger.record("info/sum_ran_into_lava", self.sum_lava) + if "collision" in infos: + if infos["collision"]: + self.sum_collision += 1 + self.logger.record("info/sum_collision", sum_collisions) + return True