You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

75 lines
2.2 KiB

from __future__ import annotations
import gymnasium as gym
import pytest
from minigrid.core.constants import COLOR_NAMES
from minigrid.core.world_object import Ball, Box
TESTING_ENVS = [
"MiniGrid-ObstructedMaze-2Dlhb",
"MiniGrid-ObstructedMaze-1Q",
"MiniGrid-ObstructedMaze-2Q",
"MiniGrid-ObstructedMaze-Full",
]
def find_ball_room(env):
for obj in env.grid.grid:
if isinstance(obj, Ball) and obj.color == COLOR_NAMES[0]:
return env.room_from_pos(*obj.cur_pos)
def find_target_key(env, color):
for obj in env.grid.grid:
if isinstance(obj, Box) and obj.contains.color == color:
return True
return False
def env_test(env_id, repeats=10000):
env = gym.make(env_id)
cnt = 0
for _ in range(repeats):
env.reset()
ball_room = find_ball_room(env)
ball_room_doors = list(filter(None, ball_room.doors))
keys_exit = [find_target_key(env, door.color) for door in ball_room_doors]
if not any(keys_exit):
cnt += 1
return (cnt / repeats) * 100
@pytest.mark.parametrize("env_id", TESTING_ENVS)
def test_solvable_env(env_id):
assert env_test(env_id + "-v1") == 0, f"{env_id} is unsolvable."
def main():
"""
Test the frequency of unsolvable situation in this environment, including
MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q, and -Full. The reason for the unsolvable
situation is that in the v0 version of these environments, the box containing
the key to the door connecting the upper-right room may be covered by the
blocking ball of the door connecting the lower-right room.
Note: Covering that occurs in MiniGrid-ObstructedMaze-Full won't lead to an
unsolvable situation.
Expected probability of unsolvable situation:
- MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
- MiniGrid-ObstructedMaze-1Q-v0: 1/ 15 = 6.67%
- MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
- MiniGrid-ObstructedMaze-Full-v0: 0%
"""
for env_id in TESTING_ENVS:
print(f"{env_id}: {env_test(env_id + '-v0'):.2f}% unsolvable.")
for env_id in TESTING_ENVS:
print(f"{env_id}: {env_test(env_id + '-v1'):.2f}% unsolvable.")
if __name__ == "__main__":
main()