From 5ab83b74608650da2786204508ab83f54901c233 Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 13:48:49 +0100 Subject: [PATCH] only create ShieldHandler when necessary; this also renames the camelCase variable logDir to log_dir --- examples/shields/rl/13_minigridsb.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index f8ebe87..487a8cb 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -29,29 +29,28 @@ def main(): formula = args.formula shield_value = args.shield_value shield_comparison = args.shield_comparison - logDir = create_log_dir(args) + log_dir = create_log_dir(args) - shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, cleanup=args.cleanup) env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env) env = ImgObsWrapper(env) env = MiniWrapper(env) if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training: + shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, cleanup=args.cleanup) env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) env = ActionMasker(env, mask_fn) else: env = ActionMasker(env, nomask_fn) - model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") + model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") - evalCallback = EvalCallback(env, best_model_save_path=logDir, - log_path=logDir, eval_freq=max(500, int(args.steps/30)), + evalCallback = EvalCallback(env, best_model_save_path=log_dir, + log_path=log_dir, eval_freq=max(500, int(args.steps/30)), deterministic=True, render=False) steps = args.steps 
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()]) - #vec_env = model.get_env() #obs = vec_env.reset() #terminated = truncated = False