Browse Source

added useful sb3 callbacks

refactoring
sp 1 year ago
parent
commit
315b0c8e7d
  1. 39
      examples/shields/rl/sb3utils.py

39
examples/shields/rl/sb3utils.py

@ -3,6 +3,8 @@ import numpy as np
import random
from utils import MiniGridShieldHandler, common_parser
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
from stable_baselines3.common.logger import Image
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self,
@ -43,3 +45,40 @@ def parse_sb3_arguments():
args = parser.parse_args()
return args
class ImageRecorderCallback(BaseCallback):
def __init__(self, verbose=0):
super().__init__(verbose)
def _on_training_start(self):
image = self.training_env.render(mode="rgb_array")
self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
def _on_step(self):
return True
class InfoCallback(BaseCallback):
"""
Custom callback for plotting additional values in tensorboard.
"""
def __init__(self, verbose=0):
super().__init__(verbose)
self.sum_goal = 0
self.sum_lava = 0
self.sum_collisions = 0
def _on_step(self) -> bool:
infos = self.locals["infos"][0]
if infos["reached_goal"]:
self.sum_goal += 1
if infos["ran_into_lava"]:
self.sum_lava += 1
self.logger.record("info/sum_reached_goal", self.sum_goal)
self.logger.record("info/sum_ran_into_lava", self.sum_lava)
if "collision" in infos:
if infos["collision"]:
self.sum_collision += 1
self.logger.record("info/sum_collision", sum_collisions)
return True
Loading…
Cancel
Save