diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index fd4c958..e2d5c41 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -65,32 +65,32 @@ def main(): eval_env = ActionMasker(eval_env, nomask_fn) else: assert(False) # TODO Do something proper - #model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") - #model.set_logger(new_logger) - #steps = args.steps - - - ## Evaluation - #eval_freq=max(500, int(args.steps/30)) - #n_eval_episodes=5 - #render_freq = eval_freq - #if shielded_evaluation(args.shielding): - # from sb3_contrib.common.maskable.evaluation import evaluate_policy - # evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, - # log_path=log_dir, eval_freq=eval_freq, - # deterministic=True, render=False, n_eval_episodes=n_eval_episodes) - # imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - #else: - # from stable_baselines3.common.evaluation import evaluate_policy - # evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, - # log_path=log_dir, eval_freq=eval_freq, - # deterministic=True, render=False, n_eval_episodes=n_eval_episodes) - - # imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - - - #model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) - #model.save(f"{log_dir}/{expname(args)}") + model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") + model.set_logger(new_logger) + steps = args.steps + + + # Evaluation + eval_freq=max(500, int(args.steps/30)) + n_eval_episodes=5 + render_freq = eval_freq + if shielded_evaluation(args.shielding): + from sb3_contrib.common.maskable.evaluation import evaluate_policy + evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, + log_path=log_dir, eval_freq=eval_freq, + deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) + else: + from stable_baselines3.common.evaluation import evaluate_policy + evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, + log_path=log_dir, eval_freq=eval_freq, + deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + + imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) + + + model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) + model.save(f"{log_dir}/{expname(args)}") if __name__ == '__main__':