Browse Source

configure logger manually to change csv filename

refactoring
sp 10 months ago
parent
commit
b696dac5f6
  1. 6
      examples/shields/rl/13_minigridsb.py

6
examples/shields/rl/13_minigridsb.py

@ -1,7 +1,7 @@
from sb3_contrib import MaskablePPO from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.wrappers import ActionMasker from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.logger import configure
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat
import gymnasium as gym import gymnasium as gym
@ -10,7 +10,7 @@ from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
import time import time
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
from stable_baselines3.common.callbacks import EvalCallback from stable_baselines3.common.callbacks import EvalCallback
@ -31,7 +31,7 @@ def main():
shield_value = args.shield_value shield_value = args.shield_value
shield_comparison = args.shield_comparison shield_comparison = args.shield_comparison
log_dir = create_log_dir(args) log_dir = create_log_dir(args)
new_logger = configure(log_dir, ["csv", "tensorboard"])
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
env = gym.make(args.env, render_mode="rgb_array") env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env) env = RGBImgObsWrapper(env)

Loading…
Cancel
Save