Browse Source

refactored training without shield

refactoring
sp 10 months ago
parent
commit
2bcb38f6af
  1. 13
      examples/shields/rl/13_minigridsb.py
  2. 6
      examples/shields/rl/sb3utils.py

13
examples/shields/rl/13_minigridsb.py

@ -19,6 +19,9 @@ GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
def mask_fn(env: gym.Env): def mask_fn(env: gym.Env):
return env.create_action_mask() return env.create_action_mask()
def nomask_fn(env: gym.Env):
return [1.0] * 7
def main(): def main():
args = parse_sb3_arguments() args = parse_sb3_arguments()
@ -26,17 +29,21 @@ def main():
formula = args.formula formula = args.formula
shield_value = args.shield_value shield_value = args.shield_value
shield_comparison = args.shield_comparison shield_comparison = args.shield_comparison
logDir = create_log_dir(args)
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison)
env = gym.make(args.env, render_mode="rgb_array") env = gym.make(args.env, render_mode="rgb_array")
env = RGBImgObsWrapper(env) # Get pixel observations env = RGBImgObsWrapper(env) # Get pixel observations
env = ImgObsWrapper(env) # Get rid of the 'mission' field env = ImgObsWrapper(env) # Get rid of the 'mission' field
env = MiniWrapper(env) env = MiniWrapper(env)
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn)
logDir = create_log_dir(args)
if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training:
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
env = ActionMasker(env, mask_fn)
else:
env = ActionMasker(env, nomask_fn)
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto")
evalCallback = EvalCallback(env, best_model_save_path=logDir, evalCallback = EvalCallback(env, best_model_save_path=logDir,
log_path=logDir, eval_freq=max(500, int(args.steps/30)), log_path=logDir, eval_freq=max(500, int(args.steps/30)),
deterministic=True, render=False) deterministic=True, render=False)

6
examples/shields/rl/sb3utils.py

@ -3,7 +3,7 @@ import numpy as np
import random import random
from utils import MiniGridShieldHandler, common_parser from utils import MiniGridShieldHandler, common_parser
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.logger import Image from stable_baselines3.common.logger import Image
class MiniGridSbShieldingWrapper(gym.core.Wrapper): class MiniGridSbShieldingWrapper(gym.core.Wrapper):
@ -11,11 +11,9 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
env, env,
shield_handler : MiniGridShieldHandler, shield_handler : MiniGridShieldHandler,
create_shield_at_reset = True, create_shield_at_reset = True,
mask_actions=True,
): ):
super().__init__(env) super().__init__(env)
self.shield_handler = shield_handler self.shield_handler = shield_handler
self.mask_actions = mask_actions
self.create_shield_at_reset = create_shield_at_reset self.create_shield_at_reset = create_shield_at_reset
shield = self.shield_handler.create_shield(env=self.env) shield = self.shield_handler.create_shield(env=self.env)
@ -30,7 +28,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
if self.create_shield_at_reset and self.mask_actions:
if self.create_shield_at_reset:
shield = self.shield_handler.create_shield(env=self.env) shield = self.shield_handler.create_shield(env=self.env)
self.shield = shield self.shield = shield
return obs, infos return obs, infos

Loading…
Cancel
Save