From 028c94262503abc0d0e07d9738558243a3b0197a Mon Sep 17 00:00:00 2001 From: sp Date: Tue, 16 Jan 2024 17:48:29 +0100 Subject: [PATCH] evaluate sb3 training WIP: Not a 100% sure whether the masking will be used in the evaluation --- examples/shields/rl/13_minigridsb.py | 33 +++++++++++++++++++++------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index f411aaf..7d95ffb 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -15,6 +15,7 @@ from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecor from stable_baselines3.common.callbacks import EvalCallback import os, sys +from copy import deepcopy GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") def mask_fn(env: gym.Env): @@ -27,32 +28,48 @@ def nomask_fn(env: gym.Env): def main(): args = parse_sb3_arguments() + formula = args.formula shield_value = args.shield_value shield_comparison = args.shield_comparison log_dir = create_log_dir(args) - new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) + new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) + if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training or args.shielding == ShieldingConfig.Evaluation: + shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env) env = ImgObsWrapper(env) env = MiniWrapper(env) - if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training: - shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) + eval_env = deepcopy(env) + eval_env.disable_random_start() + if args.shielding == ShieldingConfig.Full: env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) env = ActionMasker(env, mask_fn) - else: + eval_env = MiniGridSbShieldingWrapper(eval_env, shield_handler=shield_handler, create_shield_at_reset=False) + eval_env = ActionMasker(eval_env, mask_fn) + elif args.shielding == ShieldingConfig.Training: + env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) + env = ActionMasker(env, mask_fn) + eval_env = ActionMasker(eval_env, nomask_fn) + elif args.shielding == ShieldingConfig.Evaluation: + env = ActionMasker(env, nomask_fn) + eval_env = MiniGridSbShieldingWrapper(eval_env, shield_handler=shield_handler, create_shield_at_reset=False) + eval_env = ActionMasker(eval_env, mask_fn) + elif args.shielding == ShieldingConfig.Disabled: env = ActionMasker(env, nomask_fn) + eval_env = ActionMasker(eval_env, nomask_fn) + else: + assert(False) # TODO Do something proper model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") model.set_logger(new_logger) - - evalCallback = EvalCallback(env, best_model_save_path=log_dir, + evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, log_path=log_dir, eval_freq=max(500, int(args.steps/30)), - deterministic=True, render=False) + deterministic=True, render=False, n_eval_episodes=5) steps = args.steps - model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()]) + model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback]) #vec_env = model.get_env() #obs = vec_env.reset()