Browse Source

make multi callbacks

refactoring
sp 11 months ago
parent
commit
adfb4034ce
  1. 3
      examples/shields/rl/15_train_eval_tune.py

3
examples/shields/rl/15_train_eval_tune.py

@ -10,6 +10,7 @@ from ray.tune.logger import UnifiedLogger
from ray.rllib.models import ModelCatalog from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithm.callbacks import make_multi_callbacks
from ray.air import session from ray.air import session
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
@ -78,7 +79,7 @@ def ppo(args):
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},) },)
.framework("torch") .framework("torch")
.callbacks([MyCallbacks, ShieldInfoCallback])
.callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback]))
.evaluation(evaluation_config={ .evaluation(evaluation_config={
"evaluation_interval": 1, "evaluation_interval": 1,
"evaluation_duration": 10, "evaluation_duration": 10,

Loading…
Cancel
Save