|  |  | @ -104,6 +104,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): | 
			
		
	
		
			
				
					|  |  |  |         print(F"Shielding is {self.mask_actions}") | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def create_action_mask(self): | 
			
		
	
		
			
				
					|  |  |  |         print(f'shielding is {self.mask_actions}') | 
			
		
	
		
			
				
					|  |  |  |         if not self.mask_actions: | 
			
		
	
		
			
				
					|  |  |  |             ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) | 
			
		
	
		
			
				
					|  |  |  |             return ret | 
			
		
	
	
		
			
				
					|  |  | @ -185,9 +186,18 @@ def shielding_env_creater(config): | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     probability_intended = args.probability_intended | 
			
		
	
		
			
				
					|  |  |  |     probability_displacement = args.probability_displacement | 
			
		
	
		
			
				
					|  |  |  |     probability_turn_intended = args.probability_turn_intended | 
			
		
	
		
			
				
					|  |  |  |     probability_turn_displacement = args.probability_turn_displacement | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement) | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     env = gym.make(name, | 
			
		
	
		
			
				
					|  |  |  |                   randomize_start=True, | 
			
		
	
		
			
				
					|  |  |  |                   probability_intended=probability_intended, | 
			
		
	
		
			
				
					|  |  |  |                   probability_displacement=probability_displacement,  | 
			
		
	
		
			
				
					|  |  |  |                   probability_turn_displacement=probability_turn_displacement, | 
			
		
	
		
			
				
					|  |  |  |                   probability_turn_intended=probability_turn_intended) | 
			
		
	
		
			
				
					|  |  |  |                    | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     env = OneHotShieldingWrapper(env, | 
			
		
	
		
			
				
					|  |  |  |                         config.vector_index if hasattr(config, "vector_index") else 0, | 
			
		
	
	
		
			
				
					|  |  | 
 |