From 71854bae0124472e1b9ac96f888db46f990a109d Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 09:47:36 +0100 Subject: [PATCH] init evalCallback for training with sb3 --- examples/shields/rl/13_minigridsb.py | 39 +++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 0fab627..ccc72dc 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -11,6 +11,7 @@ import time from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback +from stable_baselines3.common.callbacks import EvalCallback import os @@ -33,27 +34,29 @@ def main(): env = MiniWrapper(env) env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) - model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args)) + logDir = create_log_dir(args) + model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir) + evalCallback = EvalCallback(env, best_model_save_path=logDir, + log_path=logDir, eval_freq=max(500, int(args.steps/30)), + deterministic=True, render=False) steps = args.steps - model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1) - - - print("Learning done, hit enter") - input("") - vec_env = model.get_env() - obs = vec_env.reset() - terminated = truncated = False - while not terminated and not truncated: - action_masks = None - action, _states = model.predict(obs, action_masks=action_masks) - print(action) - obs, reward, terminated, truncated, info = env.step(action) - # action, _states = model.predict(obs, deterministic=True) - # obs, rewards, dones, info = vec_env.step(action) - vec_env.render("human") - time.sleep(0.2) + model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()]) + + + #vec_env = model.get_env() + #obs = vec_env.reset() + #terminated = truncated = False + #while not terminated and not truncated: + # action_masks = None + # action, _states = model.predict(obs, action_masks=action_masks) + # print(action) + # obs, reward, terminated, truncated, info = env.step(action) + # # action, _states = model.predict(obs, deterministic=True) + # # obs, rewards, dones, info = vec_env.step(action) + # vec_env.render("human") + # time.sleep(0.2)