diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index e2d5c41..fd4c958 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -65,32 +65,32 @@ def main(): eval_env = ActionMasker(eval_env, nomask_fn) else: assert(False) # TODO Do something proper - model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") - model.set_logger(new_logger) - steps = args.steps - - - # Evaluation - eval_freq=max(500, int(args.steps/30)) - n_eval_episodes=5 - render_freq = eval_freq - if shielded_evaluation(args.shielding): - from sb3_contrib.common.maskable.evaluation import evaluate_policy - evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, - log_path=log_dir, eval_freq=eval_freq, - deterministic=True, render=False, n_eval_episodes=n_eval_episodes) - imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - else: - from stable_baselines3.common.evaluation import evaluate_policy - evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, - log_path=log_dir, eval_freq=eval_freq, - deterministic=True, render=False, n_eval_episodes=n_eval_episodes) - - imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - - - model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) - model.save(f"{log_dir}/{expname(args)}") + #model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") + #model.set_logger(new_logger) + #steps = args.steps + + + ## Evaluation + #eval_freq=max(500, int(args.steps/30)) + #n_eval_episodes=5 + #render_freq = eval_freq + #if shielded_evaluation(args.shielding): + # from sb3_contrib.common.maskable.evaluation import evaluate_policy + # evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, + # log_path=log_dir, eval_freq=eval_freq, + # deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + # imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) + #else: + # from stable_baselines3.common.evaluation import evaluate_policy + # evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, + # log_path=log_dir, eval_freq=eval_freq, + # deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + + # imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) + + + #model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) + #model.save(f"{log_dir}/{expname(args)}") if __name__ == '__main__':