Browse Source

make multi callbacks

refactoring
sp 11 months ago
parent
commit
adfb4034ce
  1. 3
      examples/shields/rl/15_train_eval_tune.py

3
examples/shields/rl/15_train_eval_tune.py

@ -10,6 +10,7 @@ from ray.tune.logger import UnifiedLogger
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithm.callbacks import make_multi_callbacks
from ray.air import session
from torch_action_mask_model import TorchActionMaskModel
@ -78,7 +79,7 @@ def ppo(args):
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},)
.framework("torch")
.callbacks([MyCallbacks, ShieldInfoCallback])
.callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback]))
.evaluation(evaluation_config={
"evaluation_interval": 1,
"evaluation_duration": 10,

Loading…
Cancel
Save