You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							75 lines
						
					
					
						
							2.2 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							75 lines
						
					
					
						
							2.2 KiB
						
					
					
				
								from __future__ import annotations
							 | 
						|
								
							 | 
						|
								import gymnasium as gym
							 | 
						|
								import pytest
							 | 
						|
								
							 | 
						|
								from minigrid.core.constants import COLOR_NAMES
							 | 
						|
								from minigrid.core.world_object import Ball, Box
							 | 
						|
								
							 | 
						|
								# Base ObstructedMaze environment IDs *without* the "-vN" suffix; the
								# test and main() append "-v0" / "-v1" before calling gym.make.
								TESTING_ENVS = [
								    f"MiniGrid-ObstructedMaze-{variant}"
								    for variant in ("2Dlhb", "1Q", "2Q", "Full")
								]
							 | 
						|
								
							 | 
						|
								
							 | 
						|
								def find_ball_room(env):
								    """Return the room that holds the ball colored ``COLOR_NAMES[0]``.
								
								    Args:
								        env: The environment, either as returned by ``gym.make`` or
								            already unwrapped. Grid state is read from ``env.unwrapped``
								            because Gymnasium wrappers no longer forward attribute access
								            to the wrapped environment (deprecated in 0.29, removed in 1.0).
								
								    Returns:
								        The result of ``room_from_pos`` for the ball's position, or
								        ``None`` if no ball of that color is on the grid.
								    """
								    base_env = env.unwrapped
								    for obj in base_env.grid.grid:
								        if isinstance(obj, Ball) and obj.color == COLOR_NAMES[0]:
								            return base_env.room_from_pos(*obj.cur_pos)
								    return None
							 | 
						|
								
							 | 
						|
								
							 | 
						|
								def find_target_key(env, color):
								    """Report whether some box on the grid contains an object of ``color``.
								
								    Only objects stored inside a ``Box`` are considered; anything lying
								    loose on the grid is ignored. This helper is used to check whether the
								    key for a door of ``color`` is still available. It assumes the keys in
								    the tested environments are hidden in boxes (TODO confirm per variant).
								
								    Args:
								        env: The environment, wrapped or unwrapped. Grid state is read
								            from ``env.unwrapped`` because Gymnasium wrappers no longer
								            forward attribute access to the wrapped environment.
								        color: Color name to look for on the box's contents.
								
								    Returns:
								        ``True`` if a box holding an object of ``color`` exists, else
								        ``False``.
								    """
								    base_env = env.unwrapped
								    return any(
								        isinstance(obj, Box)
								        # An empty box has ``contains is None``; skip it instead of
								        # raising AttributeError on ``None.color``.
								        and obj.contains is not None
								        and obj.contains.color == color
								        for obj in base_env.grid.grid
								    )
							 | 
						|
								
							 | 
						|
								
							 | 
						|
								def env_test(env_id, repeats=10000):
								    """Estimate the percentage of unsolvable layouts produced by ``env_id``.
								
								    Each sample resets the environment, then takes these steps:
								
								    - Locate the room holding the ball colored ``COLOR_NAMES[0]``.
								    - For each door of that room, check whether a box still holds an
								      object of the door's color.
								    - Count the layout as unsolvable when none of the room's doors has
								      such a box.
								
								    Args:
								        env_id: Full Gymnasium environment ID, including the version
								            suffix (e.g. ``"MiniGrid-ObstructedMaze-1Q-v1"``).
								        repeats: Number of resets to sample; must be positive.
								
								    Returns:
								        Percentage (0-100) of sampled layouts counted as unsolvable.
								
								    Raises:
								        ValueError: If ``repeats`` is not positive (the ratio would be
								            undefined or meaningless).
								    """
								    if repeats <= 0:
								        raise ValueError(f"repeats must be positive, got {repeats!r}")
								
								    env = gym.make(env_id)
								    try:
								        # Wrappers returned by gym.make do not forward attribute lookups
								        # in Gymnasium >= 1.0, so inspect the base environment directly.
								        base_env = env.unwrapped
								        unsolvable = 0
								        for _ in range(repeats):
								            env.reset()
								            ball_room = find_ball_room(base_env)
								            # ``doors`` contains falsy placeholders for walls without a door.
								            doors = [door for door in ball_room.doors if door]
								            if not any(find_target_key(base_env, door.color) for door in doors):
								                unsolvable += 1
								    finally:
								        env.close()
								
								    return unsolvable / repeats * 100
							 | 
						|
								
							 | 
						|
								
							 | 
						|
								@pytest.mark.parametrize("env_id", TESTING_ENVS)
								def test_solvable_env(env_id):
								    """No sampled layout of the ``-v1`` release may be counted unsolvable."""
								    unsolvable_pct = env_test(f"{env_id}-v1")
								    assert unsolvable_pct == 0, f"{env_id} is unsolvable."
							 | 
						|
								
							 | 
						|
								
							 | 
						|
								def main():
								    """
								    Print how often each tested ObstructedMaze variant produces an
								    unsolvable layout. Results are printed first for the -v0 and then for
								    the -v1 release of every ID in TESTING_ENVS
								    (MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q and -Full).
								
								    Unsolvable layouts arise in the v0 releases of these environments. There,
								    the box holding the key to the door of the upper-right room may be
								    covered by the ball that blocks the door of the lower-right room.
								
								    Note: such covering in MiniGrid-ObstructedMaze-Full does not lead to an
								    unsolvable situation.
								
								    Expected probability of an unsolvable layout:
								    - MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
								    - MiniGrid-ObstructedMaze-1Q-v0: 1 / 15 = 6.67%
								    - MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
								    - MiniGrid-ObstructedMaze-Full-v0: 0%
								    - every -v1 environment: 0%
								    """
								    for version in ("v0", "v1"):
								        for env_id in TESTING_ENVS:
								            full_id = f"{env_id}-{version}"
								            # Print the versioned ID. Otherwise the v0 and v1 result lines
								            # would be indistinguishable.
								            print(f"{full_id}: {env_test(full_id):.2f}% unsolvable.")
							 | 
						|
								
							 | 
						|
								
							 | 
						|
								if __name__ == "__main__":
								    # Running this file directly (rather than via pytest) prints the
								    # measured unsolvable rates for the v0 and v1 environments.
								    main()
							 |