You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.2 KiB
75 lines
2.2 KiB
from __future__ import annotations
|
|
|
|
import gymnasium as gym
|
|
import pytest
|
|
|
|
from minigrid.core.constants import COLOR_NAMES
|
|
from minigrid.core.world_object import Ball, Box
|
|
|
|
# ObstructedMaze variants under test. Each id is completed with a "-v0" or
# "-v1" version suffix before being passed to gym.make (see env_test callers).
TESTING_ENVS = [
    f"MiniGrid-ObstructedMaze-{variant}"
    for variant in ("2Dlhb", "1Q", "2Q", "Full")
]
|
|
|
|
|
|
def find_ball_room(env):
    """Return the room containing the first ball of colour ``COLOR_NAMES[0]``.

    Cells of ``env.grid.grid`` are scanned in order. The first :class:`Ball`
    whose ``color`` equals ``COLOR_NAMES[0]`` has its ``cur_pos`` mapped to a
    room via ``env.room_from_pos``. If no such ball is on the grid, ``None``
    is returned.

    NOTE(review): presumably this is the ball the agent must reach in
    ObstructedMaze; confirm against the environment definition.
    """
    ball = next(
        (
            cell
            for cell in env.grid.grid
            if isinstance(cell, Ball) and cell.color == COLOR_NAMES[0]
        ),
        None,
    )
    if ball is None:
        return None
    return env.room_from_pos(*ball.cur_pos)
|
|
|
|
|
|
def find_target_key(env, color):
    """Report whether any box on the grid holds an object of ``color``.

    Used to check whether the key for a door is still obtainable. If the box
    that held it is no longer present in ``env.grid.grid`` (for example,
    because another object was placed over it), this returns ``False``.

    Args:
        env: environment exposing ``grid.grid``, an iterable of grid cells.
        color: colour name compared against the box contents' ``color``.

    Returns:
        ``True`` if some :class:`Box` contains an object whose ``color``
        equals ``color``; ``False`` otherwise. Empty boxes never match.
    """
    return any(
        isinstance(cell, Box)
        # A Box may be empty (contains is None); skip it rather than raising
        # AttributeError on None.color.
        and cell.contains is not None
        and cell.contains.color == color
        for cell in env.grid.grid
    )
|
|
|
|
|
|
def env_test(env_id, repeats=10000, seed=None):
    """Estimate how often ``env_id`` resets into an unsolvable layout.

    After each reset, the room holding the target ball is located with
    ``find_ball_room``. Then, for every door of that room, ``find_target_key``
    checks whether a box holding a key of the door's colour is still on the
    grid. A reset where no door of the ball's room has such a key is counted
    as unsolvable.

    Args:
        env_id: full Gymnasium id including the version suffix,
            e.g. ``"MiniGrid-ObstructedMaze-1Q-v1"``.
        repeats: number of resets to sample; must be positive.
        seed: optional seed applied to the first reset only, so a run can be
            reproduced. Later resets continue from that RNG state.

    Returns:
        Percentage (0-100) of sampled resets that were unsolvable.

    Raises:
        ValueError: if ``repeats`` is not positive (avoids division by zero).
        RuntimeError: if a reset produces no target ball to locate.
    """
    if repeats <= 0:
        raise ValueError(f"repeats must be positive, got {repeats}")

    env = gym.make(env_id)
    # gym.make returns a wrapper stack. grid / room_from_pos live on the base
    # env, and Gymnasium deprecated (and in 1.0 removed) forwarding attribute
    # access through wrappers, so query the unwrapped env directly.
    base_env = env.unwrapped
    unsolvable = 0
    try:
        for i in range(repeats):
            # Seed only the first reset; reseeding every time would replay
            # one identical layout `repeats` times.
            env.reset(seed=seed if i == 0 else None)

            ball_room = find_ball_room(base_env)
            if ball_room is None:
                raise RuntimeError(f"{env_id}: no target ball found after reset")

            # room.doors holds None for sides without a door.
            doors = [door for door in ball_room.doors if door is not None]
            keys_exist = [find_target_key(base_env, door.color) for door in doors]
            if not any(keys_exist):
                unsolvable += 1
    finally:
        # Always release the environment, even when a check above raises.
        env.close()

    return unsolvable / repeats * 100
|
|
|
|
|
|
@pytest.mark.parametrize("env_id", TESTING_ENVS)
def test_solvable_env(env_id):
    """The v1 release of each variant must never reset into an unsolvable layout."""
    versioned_id = f"{env_id}-v1"
    unsolvable_rate = env_test(versioned_id)
    assert unsolvable_rate == 0, f"{env_id} is unsolvable."
|
|
|
|
|
|
def main():
    """Print the unsolvable-layout frequency of each ObstructedMaze variant.

    Covers MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q and -Full, for both the v0
    and v1 releases. Unsolvable layouts arise in v0 because the box containing
    the key to the door of the upper-right room may be covered by the ball
    blocking the door of the lower-right room.

    Note: such covering also occurs in MiniGrid-ObstructedMaze-Full, but there
    it does not make the environment unsolvable.

    Expected probability of an unsolvable layout:
        - MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
        - MiniGrid-ObstructedMaze-1Q-v0: 1 / 15 = 6.67%
        - MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
        - MiniGrid-ObstructedMaze-Full-v0: 0%
        - every v1 variant: 0% (asserted by ``test_solvable_env``)
    """
    # Same order as before: all v0 variants first, then all v1 variants.
    for version in ("v0", "v1"):
        for env_id in TESTING_ENVS:
            versioned_id = f"{env_id}-{version}"
            # Label each line with the full versioned id; printing only the
            # base id made the v0 and v1 results indistinguishable.
            print(f"{versioned_id}: {env_test(versioned_id):.2f}% unsolvable.")


if __name__ == "__main__":
    main()
|