|
|
@ -69,8 +69,9 @@ def main(): |
|
|
|
|
|
|
|
|
|
|
|
# Evaluation |
|
|
|
eval_freq=max(500, int(args.steps/30)) |
|
|
|
n_eval_episodes=5 |
|
|
|
#eval_freq=max(500, int(args.steps/30)) |
|
|
|
eval_freq=10000 |
|
|
|
n_eval_episodes=25 |
|
|
|
render_freq = eval_freq |
|
|
|
if shielded_evaluation(args.shielding): |
|
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
@ -86,7 +87,7 @@ def main(): |
|
|
|
|
|
|
|
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) |
|
|
|
|
|
|
|
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback()]) |
|
|
|
model.learn(steps,callback=[evalCallback, imageAndVideoCallback, InfoCallback()]) |
|
|
|
model.save(f"{log_dir}/{expname(args)}") |
|
|
|
|
|
|
|
|
|
|
|