|  | @ -19,6 +19,9 @@ GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") | 
		
	
		
			
				|  |  | def mask_fn(env: gym.Env): |  |  | def mask_fn(env: gym.Env): | 
		
	
		
			
				|  |  |     return env.create_action_mask() |  |  |     return env.create_action_mask() | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | def nomask_fn(env: gym.Env): | 
		
	
		
			
				|  |  |  |  |  |     return [1.0] * 7 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | def main(): |  |  | def main(): | 
		
	
		
			
				|  |  |     args = parse_sb3_arguments() |  |  |     args = parse_sb3_arguments() | 
		
	
	
		
			
				|  | @ -26,17 +29,21 @@ def main(): | 
		
	
		
			
				|  |  |     formula = args.formula |  |  |     formula = args.formula | 
		
	
		
			
				|  |  |     shield_value = args.shield_value |  |  |     shield_value = args.shield_value | 
		
	
		
			
				|  |  |     shield_comparison = args.shield_comparison |  |  |     shield_comparison = args.shield_comparison | 
		
	
		
			
				|  |  |  |  |  |     logDir = create_log_dir(args) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) |  |  |     shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) | 
		
	
		
			
				|  |  |     env = gym.make(args.env, render_mode="rgb_array") |  |  |     env = gym.make(args.env, render_mode="rgb_array") | 
		
	
		
			
				|  |  |     env = RGBImgObsWrapper(env) # Get pixel observations |  |  |     env = RGBImgObsWrapper(env) # Get pixel observations | 
		
	
		
			
				|  |  |     env = ImgObsWrapper(env) # Get rid of the 'mission' field |  |  |     env = ImgObsWrapper(env) # Get rid of the 'mission' field | 
		
	
		
			
				|  |  |     env = MiniWrapper(env) |  |  |     env = MiniWrapper(env) | 
		
	
		
			
				|  |  |     env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training: | 
		
	
		
			
				|  |  |  |  |  |         env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) | 
		
	
		
			
				|  |  |         env = ActionMasker(env, mask_fn) |  |  |         env = ActionMasker(env, mask_fn) | 
		
	
		
			
				|  |  |     logDir = create_log_dir(args) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     else: | 
		
	
		
			
				|  |  |  |  |  |         env = ActionMasker(env, nomask_fn) | 
		
	
		
			
				|  |  |     model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") |  |  |     model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=logDir, device="auto") | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |     evalCallback = EvalCallback(env, best_model_save_path=logDir, |  |  |     evalCallback = EvalCallback(env, best_model_save_path=logDir, | 
		
	
		
			
				|  |  |                                 log_path=logDir, eval_freq=max(500,  int(args.steps/30)), |  |  |                                 log_path=logDir, eval_freq=max(500,  int(args.steps/30)), | 
		
	
		
			
				|  |  |                                 deterministic=True, render=False) |  |  |                                 deterministic=True, render=False) | 
		
	
	
		
			
				|  | 
 |