|
|
@@ -11,6 +11,7 @@ import time
 
 from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper
 from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
+from stable_baselines3.common.callbacks import EvalCallback
 
 import os
@@ -33,27 +34,29 @@ def main():
     env = MiniWrapper(env)
     env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
     env = ActionMasker(env, mask_fn)
-    model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args))
+    logDir = create_log_dir(args)
+    model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir)
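+    # EvalCallback runs a periodic evaluation of the current policy and saves the
+    # best model seen so far into logDir (alongside the TensorBoard logs)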
+    evalCallback = EvalCallback(env, best_model_save_path=logDir,
+                                log_path=logDir, eval_freq=max(500, int(args.steps/30)),
+                                deterministic=True, render=False)
     steps = args.steps
 
-    model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1)
+    # register evalCallback as well; a callback only runs if it is passed to learn()
+    model.learn(steps, callback=[ImageRecorderCallback(), InfoCallback(), evalCallback], log_interval=1)
 
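+    # manual rollout of the trained policy so its behaviour can be watched step by step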
print("Learning done, hit enter") |
|
|
|
input("") |
|
|
|
vec_env = model.get_env() |
|
|
|
obs = vec_env.reset() |
|
|
|
terminated = truncated = False |
|
|
|
while not terminated and not truncated: |
|
|
|
action_masks = None |
|
|
|
action, _states = model.predict(obs, action_masks=action_masks) |
|
|
|
print(action) |
|
|
|
-        obs, reward, terminated, truncated, info = env.step(action)
+        # obs comes from the VecEnv, so the VecEnv must be stepped as well; it returns
+        # batched (obs, rewards, dones, infos) instead of a terminated/truncated 5-tuple
+        obs, rewards, dones, info = vec_env.step(action)
+        terminated = truncated = dones[0]
         # action, _states = model.predict(obs, deterministic=True)
-        # obs, rewards, dones, info = vec_env.step(action)
         vec_env.render("human")
         time.sleep(0.2)
 
     model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()])
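+    # earlier manual-rollout draft, kept commented out for reference: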
     #vec_env = model.get_env()
     #obs = vec_env.reset()
     #terminated = truncated = False
     #while not terminated and not truncated:
     #    action_masks = None
     #    action, _states = model.predict(obs, action_masks=action_masks)
     #    print(action)
     #    obs, reward, terminated, truncated, info = env.step(action)
     #    # action, _states = model.predict(obs, deterministic=True)
     #    # obs, rewards, dones, info = vec_env.step(action)
     #    vec_env.render("human")
     #    time.sleep(0.2)