Browse Source

log to csv and tensorboard only

refactoring
sp 9 months ago
parent
commit
703b213248
  1. 3
      examples/shields/rl/13_minigridsb.py

3
examples/shields/rl/13_minigridsb.py

@ -1,6 +1,7 @@
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.logger import configure
import gymnasium as gym
@ -30,6 +31,7 @@ def main():
shield_value = args.shield_value
shield_comparison = args.shield_comparison
log_dir = create_log_dir(args)
new_logger = configure(log_dir, ["csv", "tensorboard"])
env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env)
@ -42,6 +44,7 @@ def main():
else:
env = ActionMasker(env, nomask_fn)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
model.set_logger(new_logger)
evalCallback = EvalCallback(env, best_model_save_path=log_dir,

Loading…
Cancel
Save