| 
					
					
						
							
						
					
					
				 | 
				@ -68,34 +68,34 @@ def trial_name_creator(trial : Trial): | 
			
		
		
	
		
			
				 | 
				 | 
				def ppo(args): | 
				 | 
				 | 
				def ppo(args): | 
			
		
		
	
		
			
				 | 
				 | 
				    register_minigrid_shielding_env(args) | 
				 | 
				 | 
				    register_minigrid_shielding_env(args) | 
			
		
		
	
		
			
				 | 
				 | 
				    logdir = args.log_dir | 
				 | 
				 | 
				    logdir = args.log_dir | 
			
		
		
	
		
			
				 | 
				 | 
				     | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    config = (PPOConfig() | 
				 | 
				 | 
				    config = (PPOConfig() | 
			
		
		
	
		
			
				 | 
				 | 
				        .rollouts(num_rollout_workers=args.workers) | 
				 | 
				 | 
				        .rollouts(num_rollout_workers=args.workers) | 
			
		
		
	
		
			
				 | 
				 | 
				        .resources(num_gpus=0) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        .resources(num_gpus=args.num_gpus) | 
			
		
		
	
		
			
				 | 
				 | 
				        .environment( env="mini-grid-shielding", | 
				 | 
				 | 
				        .environment( env="mini-grid-shielding", | 
			
		
		
	
		
			
				 | 
				 | 
				                      env_config={"name": args.env, | 
				 | 
				 | 
				                      env_config={"name": args.env, | 
			
		
		
	
		
			
				 | 
				 | 
				                                  "args": args,                                   | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                  "args": args, | 
			
		
		
	
		
			
				 | 
				 | 
				                                  "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, | 
				 | 
				 | 
				                                  "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, | 
			
		
		
	
		
			
				 | 
				 | 
				                                  },) | 
				 | 
				 | 
				                                  },) | 
			
		
		
	
		
			
				 | 
				 | 
				        .framework("torch") | 
				 | 
				 | 
				        .framework("torch") | 
			
		
		
	
		
			
				 | 
				 | 
				        .callbacks(MyCallbacks) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				        .evaluation(evaluation_config={  | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        .evaluation(evaluation_config={ | 
			
		
		
	
		
			
				 | 
				 | 
				                                       "evaluation_interval": 1, | 
				 | 
				 | 
				                                       "evaluation_interval": 1, | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "evaluation_duration": 10, | 
				 | 
				 | 
				                                        "evaluation_duration": 10, | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "evaluation_num_workers":1, | 
				 | 
				 | 
				                                        "evaluation_num_workers":1, | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "env": "mini-grid-shielding",  | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "env_config": {"name": args.env,  | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				                                                       "args": args,  | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				                                                       "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})         | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                        "env": "mini-grid-shielding", | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                        "env_config": {"name": args.env, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                                       "args": args, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                                       "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) | 
			
		
		
	
		
			
				 | 
				 | 
				        .rl_module(_enable_rl_module_api = False) | 
				 | 
				 | 
				        .rl_module(_enable_rl_module_api = False) | 
			
		
		
	
		
			
				 | 
				 | 
				        .debugging(logger_config={ | 
				 | 
				 | 
				        .debugging(logger_config={ | 
			
		
		
	
		
			
				 | 
				 | 
				            "type": UnifiedLogger,  | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            "type": UnifiedLogger, | 
			
		
		
	
		
			
				 | 
				 | 
				            "logdir": logdir | 
				 | 
				 | 
				            "logdir": logdir | 
			
		
		
	
		
			
				 | 
				 | 
				        }) | 
				 | 
				 | 
				        }) | 
			
		
		
	
		
			
				 | 
				 | 
				        .training(_enable_learner_api=False ,model={ | 
				 | 
				 | 
				        .training(_enable_learner_api=False ,model={ | 
			
		
		
	
		
			
				 | 
				 | 
				            "custom_model": "shielding_model"       | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            "custom_model": "shielding_model" | 
			
		
		
	
		
			
				 | 
				 | 
				        })) | 
				 | 
				 | 
				        })) | 
			
		
		
	
		
			
				 | 
				 | 
				     | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    tuner = tune.Tuner("PPO", | 
				 | 
				 | 
				    tuner = tune.Tuner("PPO", | 
			
		
		
	
		
			
				 | 
				 | 
				                       tune_config=tune.TuneConfig( | 
				 | 
				 | 
				                       tune_config=tune.TuneConfig( | 
			
		
		
	
		
			
				 | 
				 | 
				                           metric="episode_reward_mean", | 
				 | 
				 | 
				                           metric="episode_reward_mean", | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |