From 2bcb38f6afb395dfec16b3507974e0c9474037f5 Mon Sep 17 00:00:00 2001 From: sp Date: Mon, 15 Jan 2024 10:52:20 +0100 Subject: [PATCH] refactored training without shield --- examples/shields/rl/13_minigridsb.py | 13 ++++++++++--- examples/shields/rl/sb3utils.py | 6 ++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index c842d15..38fad3f 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -19,6 +19,9 @@ GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") def mask_fn(env: gym.Env): return env.create_action_mask() +def nomask_fn(env: gym.Env): + return [1.0] * 7 + def main(): args = parse_sb3_arguments() @@ -26,17 +29,21 @@ def main(): formula = args.formula shield_value = args.shield_value shield_comparison = args.shield_comparison + logDir = create_log_dir(args) shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) env = gym.make(args.env, render_mode="rgb_array") env = RGBImgObsWrapper(env) # Get pixel observations env = ImgObsWrapper(env) # Get rid of the 'mission' field env = MiniWrapper(env) - env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) - env = ActionMasker(env, mask_fn) - logDir = create_log_dir(args) + if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training: + env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) + env = ActionMasker(env, mask_fn) + else: + env = ActionMasker(env, nomask_fn) model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") + evalCallback = EvalCallback(env, best_model_save_path=logDir, log_path=logDir, eval_freq=max(500, int(args.steps/30)), deterministic=True, render=False) diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index f260c62..454552b 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -3,7 +3,7 @@ import numpy as np import random from utils import MiniGridShieldHandler, common_parser -from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback +from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback from stable_baselines3.common.logger import Image class MiniGridSbShieldingWrapper(gym.core.Wrapper): @@ -11,11 +11,9 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): env, shield_handler : MiniGridShieldHandler, create_shield_at_reset = True, - mask_actions=True, ): super().__init__(env) self.shield_handler = shield_handler - self.mask_actions = mask_actions self.create_shield_at_reset = create_shield_at_reset shield = self.shield_handler.create_shield(env=self.env) @@ -30,7 +28,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) - if self.create_shield_at_reset and self.mask_actions: + if self.create_shield_at_reset: shield = self.shield_handler.create_shield(env=self.env) self.shield = shield return obs, infos