From d7e7a2411b0bff4b73c6da2012bda73ba8422a75 Mon Sep 17 00:00:00 2001 From: sp Date: Tue, 16 Jan 2024 19:43:16 +0100 Subject: [PATCH] use shield in evaluation when full shielding --- examples/shields/rl/13_minigridsb.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 7d95ffb..578bf2e 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -1,5 +1,5 @@ from sb3_contrib import MaskablePPO -from sb3_contrib.common.maskable.evaluation import evaluate_policy +from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback from sb3_contrib.common.wrappers import ActionMasker from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat @@ -10,7 +10,7 @@ from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper import time -from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname +from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback from stable_baselines3.common.callbacks import EvalCallback @@ -35,8 +35,11 @@ def main(): log_dir = create_log_dir(args) new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) - if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training or args.shielding == ShieldingConfig.Evaluation: + + if shield_needed(args.shielding): shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) + + env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env) env = ImgObsWrapper(env) @@ -63,12 +66,20 @@ def main(): assert(False) # TODO Do something proper model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") model.set_logger(new_logger) - - evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, - log_path=log_dir, eval_freq=max(500, int(args.steps/30)), - deterministic=True, render=False, n_eval_episodes=5) steps = args.steps + eval_freq=max(500, int(args.steps/30)) + n_eval_episodes=5 + if shielded_evaluation(args.shielding): + evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, + log_path=log_dir, eval_freq=eval_freq, + deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + else: + evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, + log_path=log_dir, eval_freq=eval_freq, + deterministic=True, render=False, n_eval_episodes=n_eval_episodes) + + model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback]) #vec_env = model.get_env()