diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index ccc72dc..c842d15 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -35,7 +35,7 @@ def main(): env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) logDir = create_log_dir(args) - model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir) + model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") evalCallback = EvalCallback(env, best_model_save_path=logDir, log_path=logDir, eval_freq=max(500, int(args.steps/30)),