Browse Source

set sb3 device to auto

This automatically detects whether a GPU can be used for training.
refactoring
sp 9 months ago
parent
commit
36c04f1b81
  1. 2
      examples/shields/rl/13_minigridsb.py

2
examples/shields/rl/13_minigridsb.py

@ -35,7 +35,7 @@ def main():
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn)
logDir = create_log_dir(args)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto")
evalCallback = EvalCallback(env, best_model_save_path=logDir,
log_path=logDir, eval_freq=max(500, int(args.steps/30)),

Loading…
Cancel
Save