diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py
index f18f2af..f260c62 100644
--- a/examples/shields/rl/sb3utils.py
+++ b/examples/shields/rl/sb3utils.py
@@ -3,6 +3,8 @@ import numpy as np
 import random
 
 from utils import MiniGridShieldHandler, common_parser
+from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
+from stable_baselines3.common.logger import Image
 
 class MiniGridSbShieldingWrapper(gym.core.Wrapper):
     def __init__(self,
@@ -43,3 +45,40 @@ def parse_sb3_arguments():
     args = parser.parse_args()
 
     return args
+
+class ImageRecorderCallback(BaseCallback):
+    def __init__(self, verbose=0):
+        super().__init__(verbose)
+
+    def _on_training_start(self):
+        image = self.training_env.render(mode="rgb_array")
+        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
+
+    def _on_step(self):
+        return True
+
+
+class InfoCallback(BaseCallback):
+    """
+    Custom callback for plotting additional values in tensorboard.
+    """
+
+    def __init__(self, verbose=0):
+        super().__init__(verbose)
+        self.sum_goal = 0
+        self.sum_lava = 0
+        self.sum_collisions = 0
+
+    def _on_step(self) -> bool:
+        infos = self.locals["infos"][0]
+        if infos["reached_goal"]:
+            self.sum_goal += 1
+        if infos["ran_into_lava"]:
+            self.sum_lava += 1
+        self.logger.record("info/sum_reached_goal", self.sum_goal)
+        self.logger.record("info/sum_ran_into_lava", self.sum_lava)
+        if "collision" in infos:
+            if infos["collision"]:
+                self.sum_collision += 1
+            self.logger.record("info/sum_collision", sum_collisions)
+        return True