|
|
@ -10,6 +10,7 @@ from ray.tune.logger import UnifiedLogger |
|
|
|
from ray.rllib.models import ModelCatalog |
|
|
|
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger |
|
|
|
from ray.rllib.algorithms.algorithm import Algorithm |
|
|
|
from ray.rllib.algorithm.callbacks import make_multi_callbacks |
|
|
|
from ray.air import session |
|
|
|
|
|
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
|
@ -78,7 +79,7 @@ def ppo(args): |
|
|
|
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, |
|
|
|
},) |
|
|
|
.framework("torch") |
|
|
|
.callbacks([MyCallbacks, ShieldInfoCallback]) |
|
|
|
.callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback])) |
|
|
|
.evaluation(evaluation_config={ |
|
|
|
"evaluation_interval": 1, |
|
|
|
"evaluation_duration": 10, |
|
|
|