|
|
@ -19,7 +19,7 @@ from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name |
|
|
|
from shieldhandlers import MiniGridShieldHandler, create_shield_query |
|
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
|
from callbacks import MyCallbacks, ShieldInfoCallback |
|
|
|
from callbacks import MyCallbacks |
|
|
|
|
|
|
|
|
|
|
|
def shielding_env_creater(config): |
|
|
@ -79,7 +79,7 @@ def ppo(args): |
|
|
|
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, |
|
|
|
},) |
|
|
|
.framework("torch") |
|
|
|
.callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback])) |
|
|
|
.callbacks(MyCallbacks) |
|
|
|
.evaluation(evaluation_config={ |
|
|
|
"evaluation_interval": 1, |
|
|
|
"evaluation_duration": 10, |
|
|
|