From b696dac5f6d508ff4db7013def33ecfd26116674 Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 17:11:17 +0100 Subject: [PATCH] configure logger manually to change csv filename --- examples/shields/rl/13_minigridsb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index e0e9cd0..92bf2b4 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -1,7 +1,7 @@ from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.wrappers import ActionMasker -from stable_baselines3.common.logger import configure +from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat import gymnasium as gym @@ -10,7 +10,7 @@ from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper import time -from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper +from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback from stable_baselines3.common.callbacks import EvalCallback @@ -31,7 +31,7 @@ def main(): shield_value = args.shield_value shield_comparison = args.shield_comparison log_dir = create_log_dir(args) - new_logger = configure(log_dir, ["csv", "tensorboard"]) + new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env)