|  | @ -63,6 +63,7 @@ def register_minigrid_shielding_env(args): | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | def ppo(args): |  |  | def ppo(args): | 
		
	
		
			
				|  |  |  |  |  |     train_batch_size = 4000 | 
		
	
		
			
				|  |  |     register_minigrid_shielding_env(args) |  |  |     register_minigrid_shielding_env(args) | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
		
			
				|  |  |     config = (PPOConfig() |  |  |     config = (PPOConfig() | 
		
	
	
		
			
				|  | @ -77,17 +78,17 @@ def ppo(args): | 
		
	
		
			
				|  |  |             "logdir": create_log_dir(args) |  |  |             "logdir": create_log_dir(args) | 
		
	
		
			
				|  |  |         })     |  |  |         })     | 
		
	
		
			
				|  |  |         # .exploration(exploration_config={"exploration_fraction": 0.1}) |  |  |         # .exploration(exploration_config={"exploration_fraction": 0.1}) | 
		
	
		
			
				|  |  |         .training(_enable_learner_api=False ,model={ |  |  |  | 
		
	
		
			
				|  |  |             "custom_model": "shielding_model" |  |  |  | 
		
	
		
			
				|  |  |         })) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         .training(_enable_learner_api=False , | 
		
	
		
			
				|  |  |  |  |  |             model={"custom_model": "shielding_model"}, | 
		
	
		
			
				|  |  |  |  |  |             train_batch_size=train_batch_size)) | 
		
	
		
			
				|  |  |     # config.entropy_coeff =  0.05 |  |  |     # config.entropy_coeff =  0.05 | 
		
	
		
			
				|  |  |     algo =(    |  |  |     algo =(    | 
		
	
		
			
				|  |  |         config.build() |  |  |         config.build() | 
		
	
		
			
				|  |  |     )    |  |  |     )    | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
		
			
				|  |  | 
 |  |  |  | 
		
	
		
			
				|  |  |     for i in range(args.evaluations): |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     iterations = int((args.steps / train_batch_size)) + 1 | 
		
	
		
			
				|  |  |  |  |  |     for i in range(iterations): | 
		
	
		
			
				|  |  |         result = algo.train() |  |  |         result = algo.train() | 
		
	
		
			
				|  |  |         print(pretty_print(result)) |  |  |         print(pretty_print(result)) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
	
		
			
				|  | @ -99,6 +100,7 @@ def ppo(args): | 
		
	
		
			
				|  |  |              |  |  |              | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | def dqn(args): |  |  | def dqn(args): | 
		
	
		
			
				|  |  |  |  |  |     train_batch_size = 4000 | 
		
	
		
			
				|  |  |     register_minigrid_shielding_env(args) |  |  |     register_minigrid_shielding_env(args) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
	
		
			
				|  | @ -113,7 +115,7 @@ def dqn(args): | 
		
	
		
			
				|  |  |             "type": TBXLogger,  |  |  |             "type": TBXLogger,  | 
		
	
		
			
				|  |  |             "logdir": create_log_dir(args) |  |  |             "logdir": create_log_dir(args) | 
		
	
		
			
				|  |  |         }) |  |  |         }) | 
		
	
		
			
				|  |  |     config = config.training(hiddens=[], dueling=False, model={     |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={     | 
		
	
		
			
				|  |  |             "custom_model": "shielding_model" |  |  |             "custom_model": "shielding_model" | 
		
	
		
			
				|  |  |     }) |  |  |     }) | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
	
		
			
				|  | @ -121,7 +123,8 @@ def dqn(args): | 
		
	
		
			
				|  |  |         config.build() |  |  |         config.build() | 
		
	
		
			
				|  |  |     ) |  |  |     ) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |     for i in range(args.evaluations): |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     iterations = int((args.steps / train_batch_size)) + 1 | 
		
	
		
			
				|  |  |  |  |  |     for i in range(iterations): | 
		
	
		
			
				|  |  |         result = algo.train() |  |  |         result = algo.train() | 
		
	
		
			
				|  |  |         print(pretty_print(result)) |  |  |         print(pretty_print(result)) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
	
		
			
				|  | 
 |