diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index b9a1e26..e0e9cd0 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -1,6 +1,7 @@ from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.wrappers import ActionMasker +from stable_baselines3.common.logger import configure import gymnasium as gym @@ -30,6 +31,7 @@ def main(): shield_value = args.shield_value shield_comparison = args.shield_comparison log_dir = create_log_dir(args) + new_logger = configure(log_dir, ["csv", "tensorboard"]) env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env) @@ -42,6 +44,7 @@ def main(): else: env = ActionMasker(env, nomask_fn) model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") + model.set_logger(new_logger) evalCallback = EvalCallback(env, best_model_save_path=log_dir,