diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 38fad3f..f8ebe87 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -31,10 +31,10 @@ def main(): shield_comparison = args.shield_comparison logDir = create_log_dir(args) - shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) + shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, cleanup=args.cleanup) env = gym.make(args.env, render_mode="rgb_array") - env = RGBImgObsWrapper(env) # Get pixel observations - env = ImgObsWrapper(env) # Get rid of the 'mission' field + env = RGBImgObsWrapper(env) + env = ImgObsWrapper(env) env = MiniWrapper(env) if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training: env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)