diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index b9abc5e..8b29ced 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -10,6 +10,7 @@ from ray.tune.logger import UnifiedLogger from ray.rllib.models import ModelCatalog from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithm.callbacks import make_multi_callbacks from ray.air import session from torch_action_mask_model import TorchActionMaskModel @@ -78,7 +79,7 @@ def ppo(args): "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, },) .framework("torch") - .callbacks([MyCallbacks, ShieldInfoCallback]) + .callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback])) .evaluation(evaluation_config={ "evaluation_interval": 1, "evaluation_duration": 10,