|
@ -3,6 +3,8 @@ import numpy as np |
|
|
import random |
|
|
import random |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, common_parser |
|
|
from utils import MiniGridShieldHandler, common_parser |
|
|
|
|
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback |
|
|
|
|
|
from stable_baselines3.common.logger import Image |
|
|
|
|
|
|
|
|
class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
def __init__(self, |
|
|
def __init__(self, |
|
@ -43,3 +45,40 @@ def parse_sb3_arguments(): |
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
return args |
|
|
return args |
|
|
|
|
|
|
|
|
|
|
|
class ImageRecorderCallback(BaseCallback): |
|
|
|
|
|
def __init__(self, verbose=0): |
|
|
|
|
|
super().__init__(verbose) |
|
|
|
|
|
|
|
|
|
|
|
def _on_training_start(self): |
|
|
|
|
|
image = self.training_env.render(mode="rgb_array") |
|
|
|
|
|
self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) |
|
|
|
|
|
|
|
|
|
|
|
def _on_step(self): |
|
|
|
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InfoCallback(BaseCallback): |
|
|
|
|
|
""" |
|
|
|
|
|
Custom callback for plotting additional values in tensorboard. |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, verbose=0): |
|
|
|
|
|
super().__init__(verbose) |
|
|
|
|
|
self.sum_goal = 0 |
|
|
|
|
|
self.sum_lava = 0 |
|
|
|
|
|
self.sum_collisions = 0 |
|
|
|
|
|
|
|
|
|
|
|
def _on_step(self) -> bool: |
|
|
|
|
|
infos = self.locals["infos"][0] |
|
|
|
|
|
if infos["reached_goal"]: |
|
|
|
|
|
self.sum_goal += 1 |
|
|
|
|
|
if infos["ran_into_lava"]: |
|
|
|
|
|
self.sum_lava += 1 |
|
|
|
|
|
self.logger.record("info/sum_reached_goal", self.sum_goal) |
|
|
|
|
|
self.logger.record("info/sum_ran_into_lava", self.sum_lava) |
|
|
|
|
|
if "collision" in infos: |
|
|
|
|
|
if infos["collision"]: |
|
|
|
|
|
self.sum_collision += 1 |
|
|
|
|
|
self.logger.record("info/sum_collision", sum_collisions) |
|
|
|
|
|
return True |