|
|
@@ -35,7 +35,7 @@ def main(): |
|
|
|
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) |
|
|
|
env = ActionMasker(env, mask_fn) |
|
|
|
logDir = create_log_dir(args) |
|
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir) |
|
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") |
|
|
|
|
|
|
|
evalCallback = EvalCallback(env, best_model_save_path=logDir, |
|
|
|
log_path=logDir, eval_freq=max(500, int(args.steps/30)), |
|
|
|