From a69dab422c057df126387dbdf25e80be4573ada9 Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 18:32:15 +0100 Subject: [PATCH] log to stdout --- examples/shields/rl/13_minigridsb.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 92bf2b4..f411aaf 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -1,7 +1,7 @@ from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.wrappers import ActionMasker -from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat +from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat import gymnasium as gym @@ -14,7 +14,7 @@ from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWr from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback from stable_baselines3.common.callbacks import EvalCallback -import os +import os, sys GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") def mask_fn(env: gym.Env): @@ -31,7 +31,7 @@ def main(): shield_value = args.shield_value shield_comparison = args.shield_comparison log_dir = create_log_dir(args) - new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) + new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env)