diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 2170a9e..bc4eea2 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -69,8 +69,9 @@ def main(): # Evaluation - eval_freq=max(500, int(args.steps/30)) - n_eval_episodes=5 + #eval_freq=max(500, int(args.steps/30)) + eval_freq=10000 + n_eval_episodes=25 render_freq = eval_freq if shielded_evaluation(args.shielding): from sb3_contrib.common.maskable.evaluation import evaluate_policy @@ -86,7 +87,7 @@ def main(): imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - model.learn(steps,callback=[imageAndVideoCallback, InfoCallback()]) + model.learn(steps,callback=[evalCallback, imageAndVideoCallback, InfoCallback()]) model.save(f"{log_dir}/{expname(args)}")