Browse Source

log to stdout

refactoring
sp 9 months ago
parent
commit
a69dab422c
  1. 6
      examples/shields/rl/13_minigridsb.py

6
examples/shields/rl/13_minigridsb.py

@ -1,7 +1,7 @@
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
import gymnasium as gym
@ -14,7 +14,7 @@ from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWr
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
from stable_baselines3.common.callbacks import EvalCallback
import os
import os, sys
GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
def mask_fn(env: gym.Env):
@ -31,7 +31,7 @@ def main():
shield_value = args.shield_value
shield_comparison = args.shield_comparison
log_dir = create_log_dir(args)
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)])
env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env)

Loading…
Cancel
Save