diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py
index ccc72dc..c842d15 100644
--- a/examples/shields/rl/13_minigridsb.py
+++ b/examples/shields/rl/13_minigridsb.py
@@ -35,7 +35,7 @@ def main():
     env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
     env = ActionMasker(env, mask_fn)
     logDir = create_log_dir(args)
-    model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir)
+    model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto")
 
     evalCallback = EvalCallback(env, best_model_save_path=logDir,
                                 log_path=logDir, eval_freq=max(500,  int(args.steps/30)),