diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 07bfaa1..e2d5c41 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -89,21 +89,8 @@ def main(): imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - model.learn(steps,callback=[imageAndVideoCallback, InfoCallback()]) - - #vec_env = model.get_env() - #obs = vec_env.reset() - #terminated = truncated = False - #while not terminated and not truncated: - # action_masks = None - # action, _states = model.predict(obs, action_masks=action_masks) - # print(action) - # obs, reward, terminated, truncated, info = env.step(action) - # # action, _states = model.predict(obs, deterministic=True) - # # obs, rewards, dones, info = vec_env.step(action) - # vec_env.render("human") - # time.sleep(0.2) - + model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) + model.save(f"{log_dir}/{expname(args)}") if __name__ == '__main__':