|  |  | @ -28,16 +28,29 @@ class ShieldingConfig(Enum): | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | def extract_keys(env): | 
			
		
	
		
			
				
					|  |  |  |     keys = [] | 
			
		
	
		
			
				
					|  |  |  |     #print(env.grid) | 
			
		
	
		
			
				
					|  |  |  |     for j in range(env.grid.height): | 
			
		
	
		
			
				
					|  |  |  |         for i in range(env.grid.width): | 
			
		
	
		
			
				
					|  |  |  |             obj = env.grid.get(i,j) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             if obj and obj.type == "key": | 
			
		
	
		
			
				
					|  |  |  |                 keys.append(obj.color) | 
			
		
	
		
			
				
					|  |  |  |                 keys.append((obj, i, j)) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     if env.carrying and env.carrying.type == "key": | 
			
		
	
		
			
				
					|  |  |  |         keys.append((env.carrying, -1, -1)) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     return keys | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | def extract_doors(env): | 
			
		
	
		
			
				
					|  |  |  |     doors = [] | 
			
		
	
		
			
				
					|  |  |  |     for j in range(env.grid.height): | 
			
		
	
		
			
				
					|  |  |  |         for i in range(env.grid.width): | 
			
		
	
		
			
				
					|  |  |  |             obj = env.grid.get(i,j) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |             if obj and obj.type == "door": | 
			
		
	
		
			
				
					|  |  |  |                 doors.append(obj) | 
			
		
	
		
			
				
					|  |  |  |                  | 
			
		
	
		
			
				
					|  |  |  |     return doors | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | def create_log_dir(args): | 
			
		
	
		
			
				
					|  |  |  |     return F"{args.log_dir}sh:{args.shielding}-env:{args.env}" | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -56,6 +69,8 @@ def get_action_index_mapping(actions): | 
			
		
	
		
			
				
					|  |  |  |             return Actions.pickup | 
			
		
	
		
			
				
					|  |  |  |         elif "done" in action_str: | 
			
		
	
		
			
				
					|  |  |  |             return Actions.done     | 
			
		
	
		
			
				
					|  |  |  |         elif "drop" in action_str: | 
			
		
	
		
			
				
					|  |  |  |             return Actions.drop | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     raise ValueError(F"Action string {action_str} not supported") | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -74,7 +89,11 @@ def parse_arguments(argparse): | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-LavaSlipperyS12-v1", | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-LavaSlipperyS12-v2", | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-LavaSlipperyS12-v3", | 
			
		
	
		
			
				
					|  |  |  |                                 # "MiniGrid-DoorKey-8x8-v0",  | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-DoorKey-8x8-v0", | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-DoubleDoor-16x16-v0", | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-DoubleDoor-12x12-v0", | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-DoubleDoor-10x8-v0", | 
			
		
	
		
			
				
					|  |  |  |                                 "MiniGrid-SingleDoor-7x6-v0", | 
			
		
	
		
			
				
					|  |  |  |                                 # "MiniGrid-LockedRoom-v0", | 
			
		
	
		
			
				
					|  |  |  |                                 # "MiniGrid-FourRooms-v0",  | 
			
		
	
		
			
				
					|  |  |  |                                 # "MiniGrid-LavaGapS7-v0", | 
			
		
	
	
		
			
				
					|  |  | @ -95,6 +114,7 @@ def parse_arguments(argparse): | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--steps", default=20_000, type=int) | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--expname", default="exp") | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) | 
			
		
	
		
			
				
					|  |  |  |     args = parser.parse_args() | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     return args |