Browse Source

Only create the MiniGridShieldHandler when shielding is enabled (ShieldingConfig.Full or ShieldingConfig.Training)

This also renames the camelCase variable logDir to log_dir.
refactoring
sp 11 months ago
parent
commit
5ab83b7460
  1. 11
      examples/shields/rl/13_minigridsb.py

11
examples/shields/rl/13_minigridsb.py

@ -29,29 +29,28 @@ def main():
formula = args.formula
shield_value = args.shield_value
shield_comparison = args.shield_comparison
logDir = create_log_dir(args)
log_dir = create_log_dir(args)
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, cleanup=args.cleanup)
env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env)
env = ImgObsWrapper(env)
env = MiniWrapper(env)
if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training:
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, cleanup=args.cleanup)
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
env = ActionMasker(env, mask_fn)
else:
env = ActionMasker(env, nomask_fn)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto")
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
evalCallback = EvalCallback(env, best_model_save_path=logDir,
log_path=logDir, eval_freq=max(500, int(args.steps/30)),
evalCallback = EvalCallback(env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=max(500, int(args.steps/30)),
deterministic=True, render=False)
steps = args.steps
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()])
#vec_env = model.get_env()
#obs = vec_env.reset()
#terminated = truncated = False

Loading…
Cancel
Save