| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -68,34 +68,34 @@ def trial_name_creator(trial : Trial): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def ppo(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_minigrid_shielding_env(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logdir = args.log_dir | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = (PPOConfig() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .rollouts(num_rollout_workers=args.workers) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .resources(num_gpus=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .resources(num_gpus=args.num_gpus) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .environment( env="mini-grid-shielding", | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                      env_config={"name": args.env, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  "args": args,                                   | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  "args": args, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                  },) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .framework("torch") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .callbacks(MyCallbacks) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .evaluation(evaluation_config={  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .evaluation(evaluation_config={ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                       "evaluation_interval": 1, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "evaluation_duration": 10, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "evaluation_num_workers":1, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "env": "mini-grid-shielding",  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "env_config": {"name": args.env,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                                       "args": args,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                                       "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})         | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "env": "mini-grid-shielding", | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                        "env_config": {"name": args.env, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                                       "args": args, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                                                       "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .rl_module(_enable_rl_module_api = False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .debugging(logger_config={ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "type": UnifiedLogger,  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "type": UnifiedLogger, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "logdir": logdir | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .training(_enable_learner_api=False ,model={ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "custom_model": "shielding_model"       | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "custom_model": "shielding_model" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        })) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    tuner = tune.Tuner("PPO", | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                       tune_config=tune.TuneConfig( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                           metric="episode_reward_mean", | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |