|  | @ -1,7 +1,7 @@ | 
		
	
		
			
				|  |  | from sb3_contrib import MaskablePPO |  |  | from sb3_contrib import MaskablePPO | 
		
	
		
			
				|  |  | from sb3_contrib.common.maskable.evaluation import evaluate_policy |  |  | from sb3_contrib.common.maskable.evaluation import evaluate_policy | 
		
	
		
			
				|  |  | from sb3_contrib.common.wrappers import ActionMasker |  |  | from sb3_contrib.common.wrappers import ActionMasker | 
		
	
		
			
				|  |  | from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat |  |  |  | 
		
	
		
			
				|  |  |  |  |  | from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | import gymnasium as gym |  |  | import gymnasium as gym | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
	
		
			
				|  | @ -14,7 +14,7 @@ from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWr | 
		
	
		
			
				|  |  | from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |  |  | from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback | 
		
	
		
			
				|  |  | from stable_baselines3.common.callbacks import EvalCallback |  |  | from stable_baselines3.common.callbacks import EvalCallback | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | import os |  |  |  | 
		
	
		
			
				|  |  |  |  |  | import os, sys | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") |  |  | GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") | 
		
	
		
			
				|  |  | def mask_fn(env: gym.Env): |  |  | def mask_fn(env: gym.Env): | 
		
	
	
		
			
				|  | @ -31,7 +31,7 @@ def main(): | 
		
	
		
			
				|  |  |     shield_value = args.shield_value |  |  |     shield_value = args.shield_value | 
		
	
		
			
				|  |  |     shield_comparison = args.shield_comparison |  |  |     shield_comparison = args.shield_comparison | 
		
	
		
			
				|  |  |     log_dir = create_log_dir(args) |  |  |     log_dir = create_log_dir(args) | 
		
	
		
			
				|  |  |     new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     env = gym.make(args.env, render_mode="rgb_array") |  |  |     env = gym.make(args.env, render_mode="rgb_array") | 
		
	
		
			
				|  |  |     env = RGBImgObsWrapper(env) |  |  |     env = RGBImgObsWrapper(env) | 
		
	
	
		
			
				|  | 
 |