Browse Source

init evalCallback for training with sb3

refactoring
sp 1 year ago
parent
commit
71854bae01
  1. 39
      examples/shields/rl/13_minigridsb.py

39
examples/shields/rl/13_minigridsb.py

@ -11,6 +11,7 @@ import time
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
from stable_baselines3.common.callbacks import EvalCallback
import os
@ -33,27 +34,29 @@ def main():
env = MiniWrapper(env)
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args))
logDir = create_log_dir(args)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir)
evalCallback = EvalCallback(env, best_model_save_path=logDir,
log_path=logDir, eval_freq=max(500, int(args.steps/30)),
deterministic=True, render=False)
steps = args.steps
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1)
print("Learning done, hit enter")
input("")
vec_env = model.get_env()
obs = vec_env.reset()
terminated = truncated = False
while not terminated and not truncated:
action_masks = None
action, _states = model.predict(obs, action_masks=action_masks)
print(action)
obs, reward, terminated, truncated, info = env.step(action)
# action, _states = model.predict(obs, deterministic=True)
# obs, rewards, dones, info = vec_env.step(action)
vec_env.render("human")
time.sleep(0.2)
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()])
#vec_env = model.get_env()
#obs = vec_env.reset()
#terminated = truncated = False
#while not terminated and not truncated:
# action_masks = None
# action, _states = model.predict(obs, action_masks=action_masks)
# print(action)
# obs, reward, terminated, truncated, info = env.step(action)
# # action, _states = model.predict(obs, deterministic=True)
# # obs, rewards, dones, info = vec_env.step(action)
# vec_env.render("human")
# time.sleep(0.2)

Loading…
Cancel
Save