|
@ -78,7 +78,7 @@ def ppo(args): |
|
|
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, |
|
|
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, |
|
|
},) |
|
|
},) |
|
|
.framework("torch") |
|
|
.framework("torch") |
|
|
.callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]) |
|
|
|
|
|
|
|
|
.callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12])) |
|
|
.evaluation(evaluation_config={ |
|
|
.evaluation(evaluation_config={ |
|
|
"evaluation_interval": 1, |
|
|
"evaluation_interval": 1, |
|
|
"evaluation_duration": 10, |
|
|
"evaluation_duration": 10, |
|
@ -116,8 +116,7 @@ def ppo(args): |
|
|
name=test_name(args), |
|
|
name=test_name(args), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
, |
|
|
|
|
|
|
|
|
), |
|
|
param_space=config,) |
|
|
param_space=config,) |
|
|
|
|
|
|
|
|
results = tuner.fit() |
|
|
results = tuner.fit() |
|
|