diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index 229b49f..b9abc5e 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -78,7 +78,7 @@ def ppo(args): "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, },) .framework("torch") - .callbacks(MyCallbacks, ShieldInfoCallback) + .callbacks([MyCallbacks, ShieldInfoCallback]) .evaluation(evaluation_config={ "evaluation_interval": 1, "evaluation_duration": 10,