|  |  | @ -1,5 +1,6 @@ | 
			
		
	
		
			
				
					|  |  |  | import gymnasium as gym | 
			
		
	
		
			
				
					|  |  |  | import numpy as np | 
			
		
	
		
			
				
					|  |  |  | import random | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from minigrid.core.actions import Actions | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -15,9 +16,9 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, env, vector_index, framestack): | 
			
		
	
		
			
				
					|  |  |  |         super().__init__(env) | 
			
		
	
		
			
				
					|  |  |  |         self.framestack = framestack | 
			
		
	
		
			
				
					|  |  |  |         # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. | 
			
		
	
		
			
				
					|  |  |  |         # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. | 
			
		
	
		
			
				
					|  |  |  |         # +4: Direction. | 
			
		
	
		
			
				
					|  |  |  |         self.single_frame_dim = 49 * (11 + 6 + 3) + 4 | 
			
		
	
		
			
				
					|  |  |  |         self.single_frame_dim = 49 * (16 + 6 + 3) + 4 | 
			
		
	
		
			
				
					|  |  |  |         self.init_x = None | 
			
		
	
		
			
				
					|  |  |  |         self.init_y = None | 
			
		
	
		
			
				
					|  |  |  |         self.x_positions = [] | 
			
		
	
	
		
			
				
					|  |  | @ -66,8 +67,8 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         image = obs["data"] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. | 
			
		
	
		
			
				
					|  |  |  |         objects = one_hot(image[:, :, 0], depth=11) | 
			
		
	
		
			
				
					|  |  |  |         # One-hot the last dim into 12, 6, 3 one-hot vectors, then flatten. | 
			
		
	
		
			
				
					|  |  |  |         objects = one_hot(image[:, :, 0], depth=16) | 
			
		
	
		
			
				
					|  |  |  |         colors = one_hot(image[:, :, 1], depth=6) | 
			
		
	
		
			
				
					|  |  |  |         states = one_hot(image[:, :, 2], depth=3) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -115,12 +116,15 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): | 
			
		
	
		
			
				
					|  |  |  |         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         if cur_pos_str in self.shield and self.shield[cur_pos_str]: | 
			
		
	
		
			
				
					|  |  |  |              allowed_actions = self.shield[cur_pos_str] | 
			
		
	
		
			
				
					|  |  |  |              for allowed_action in allowed_actions: | 
			
		
	
		
			
				
					|  |  |  |                  index =  get_action_index_mapping(allowed_action[1]) # Allowed_action is a set | 
			
		
	
		
			
				
					|  |  |  |                  if index is None: | 
			
		
	
		
			
				
					|  |  |  |                      assert(False) | 
			
		
	
		
			
				
					|  |  |  |                  mask[index] = 1.0 | 
			
		
	
		
			
				
					|  |  |  |             allowed_actions = self.shield[cur_pos_str] | 
			
		
	
		
			
				
					|  |  |  |             for allowed_action in allowed_actions: | 
			
		
	
		
			
				
					|  |  |  |                 index =  get_action_index_mapping(allowed_action.labels) # Allowed_action is a set | 
			
		
	
		
			
				
					|  |  |  |                 if index is None: | 
			
		
	
		
			
				
					|  |  |  |                     assert(False) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
		
			
				
					|  |  |  |                 allowed =  random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] | 
			
		
	
		
			
				
					|  |  |  |                 mask[index] = allowed                | 
			
		
	
		
			
				
					|  |  |  |                       | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             for index, x in enumerate(mask): | 
			
		
	
		
			
				
					|  |  |  |                 mask[index] = 1.0 | 
			
		
	
	
		
			
				
					|  |  | @ -195,13 +199,13 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): | 
			
		
	
		
			
				
					|  |  |  |         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         if cur_pos_str in self.shield and self.shield[cur_pos_str]: | 
			
		
	
		
			
				
					|  |  |  |              allowed_actions = self.shield[cur_pos_str] | 
			
		
	
		
			
				
					|  |  |  |              for allowed_action in allowed_actions: | 
			
		
	
		
			
				
					|  |  |  |                  index =  get_action_index_mapping(allowed_action[1]) | 
			
		
	
		
			
				
					|  |  |  |                  if index is None: | 
			
		
	
		
			
				
					|  |  |  |             allowed_actions = self.shield[cur_pos_str] | 
			
		
	
		
			
				
					|  |  |  |             for allowed_action in allowed_actions: | 
			
		
	
		
			
				
					|  |  |  |                 index =  get_action_index_mapping(allowed_action.labels) | 
			
		
	
		
			
				
					|  |  |  |                 if index is None: | 
			
		
	
		
			
				
					|  |  |  |                      assert(False) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
		
			
				
					|  |  |  |                  mask[index] = 1.0 | 
			
		
	
		
			
				
					|  |  |  |                 mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] | 
			
		
	
		
			
				
					|  |  |  |         else: | 
			
		
	
		
			
				
					|  |  |  |             for index, x in enumerate(mask): | 
			
		
	
		
			
				
					|  |  |  |                 mask[index] = 1.0 | 
			
		
	
	
		
			
				
					|  |  | 
 |