From adfb4034ce53043769e3dae8ad6dfd0aa7c4c781 Mon Sep 17 00:00:00 2001 From: sp Date: Sat, 30 Dec 2023 11:53:21 +0100 Subject: [PATCH] make multi callbacks --- examples/shields/rl/15_train_eval_tune.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index b9abc5e..8b29ced 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -10,6 +10,7 @@ from ray.tune.logger import UnifiedLogger from ray.rllib.models import ModelCatalog from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithm.callbacks import make_multi_callbacks from ray.air import session from torch_action_mask_model import TorchActionMaskModel @@ -78,7 +79,7 @@ def ppo(args): "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, },) .framework("torch") - .callbacks([MyCallbacks, ShieldInfoCallback]) + .callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback])) .evaluation(evaluation_config={ "evaluation_interval": 1, "evaluation_duration": 10,