| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -10,6 +10,7 @@ from ray.tune.logger import UnifiedLogger | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.models import ModelCatalog | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.algorithms.algorithm import Algorithm | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.algorithm.callbacks import make_multi_callbacks | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.air import session | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from torch_action_mask_model import TorchActionMaskModel | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -78,7 +79,7 @@ def ppo(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  },) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .framework("torch") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .callbacks([MyCallbacks, ShieldInfoCallback]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback])) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .evaluation(evaluation_config={ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                       "evaluation_interval": 1, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "evaluation_duration": 10, | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |