From 36c04f1b81beaca2eb2ee78b24f5679a47e60d1d Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 09:58:28 +0100 Subject: [PATCH] set sb3 device to auto This automatically detects whether a GPU can be used for training. --- examples/shields/rl/13_minigridsb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index ccc72dc..c842d15 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -35,7 +35,7 @@ def main(): env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) logDir = create_log_dir(args) - model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir) + model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") evalCallback = EvalCallback(env, best_model_save_path=logDir, log_path=logDir, eval_freq=max(500, int(args.steps/30)),