| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -63,6 +63,7 @@ def register_minigrid_shielding_env(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def ppo(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    train_batch_size = 4000 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_minigrid_shielding_env(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = (PPOConfig() | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -77,17 +78,17 @@ def ppo(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "logdir": create_log_dir(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        })     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # .exploration(exploration_config={"exploration_fraction": 0.1}) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .training(_enable_learner_api=False ,model={ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "custom_model": "shielding_model" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        })) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .training(_enable_learner_api=False , | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            model={"custom_model": "shielding_model"}, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            train_batch_size=train_batch_size)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # config.entropy_coeff =  0.05 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    algo =(    | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        config.build() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    )    | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i in range(args.evaluations): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    iterations = int((args.steps / train_batch_size)) + 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i in range(iterations): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        result = algo.train() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print(pretty_print(result)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -99,6 +100,7 @@ def ppo(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def dqn(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    train_batch_size = 4000 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_minigrid_shielding_env(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -113,15 +115,16 @@ def dqn(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "type": TBXLogger,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "logdir": create_log_dir(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.training(hiddens=[], dueling=False, model={     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "custom_model": "shielding_model" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    algo = ( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        config.build() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					          | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i in range(args.evaluations): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    iterations = int((args.steps / train_batch_size)) + 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i in range(iterations): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        result = algo.train() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print(pretty_print(result)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |