|  | @ -1,6 +1,7 @@ | 
		
	
		
			
				|  |  | from sb3_contrib import MaskablePPO |  |  | from sb3_contrib import MaskablePPO | 
		
	
		
			
				|  |  | from sb3_contrib.common.maskable.evaluation import evaluate_policy |  |  | from sb3_contrib.common.maskable.evaluation import evaluate_policy | 
		
	
		
			
				|  |  | from sb3_contrib.common.wrappers import ActionMasker |  |  | from sb3_contrib.common.wrappers import ActionMasker | 
		
	
		
			
				|  |  |  |  |  | from stable_baselines3.common.logger import configure | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | import gymnasium as gym |  |  | import gymnasium as gym | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
	
		
			
				|  | @ -30,6 +31,7 @@ def main(): | 
		
	
		
			
				|  |  |     shield_value = args.shield_value |  |  |     shield_value = args.shield_value | 
		
	
		
			
				|  |  |     shield_comparison = args.shield_comparison |  |  |     shield_comparison = args.shield_comparison | 
		
	
		
			
				|  |  |     log_dir = create_log_dir(args) |  |  |     log_dir = create_log_dir(args) | 
		
	
		
			
				|  |  |  |  |  |     new_logger = configure(log_dir, ["csv", "tensorboard"]) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     env = gym.make(args.env, render_mode="rgb_array") |  |  |     env = gym.make(args.env, render_mode="rgb_array") | 
		
	
		
			
				|  |  |     env = RGBImgObsWrapper(env) |  |  |     env = RGBImgObsWrapper(env) | 
		
	
	
		
			
				|  | @ -42,6 +44,7 @@ def main(): | 
		
	
		
			
				|  |  |     else: |  |  |     else: | 
		
	
		
			
				|  |  |         env = ActionMasker(env, nomask_fn) |  |  |         env = ActionMasker(env, nomask_fn) | 
		
	
		
			
				|  |  |     model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") |  |  |     model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") | 
		
	
		
			
				|  |  |  |  |  |     model.set_logger(new_logger) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     evalCallback = EvalCallback(env, best_model_save_path=log_dir, |  |  |     evalCallback = EvalCallback(env, best_model_save_path=log_dir, | 
		
	
	
		
			
				|  | 
 |