You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
from __future__ import annotations
import gymnasium as gym import pytest
from minigrid.core.constants import COLOR_NAMES from minigrid.core.world_object import Ball, Box
# ObstructedMaze variants checked for unsolvable layouts. The version suffix
# ("-v0" / "-v1") is appended by the code that uses these ids.
TESTING_ENVS = [
    f"MiniGrid-ObstructedMaze-{variant}"
    for variant in ("2Dlhb", "1Q", "2Q", "Full")
]
def find_ball_room(env):
    """Return the room containing the first ``COLOR_NAMES[0]`` ball.

    Scans ``env.grid.grid`` in order for a ``Ball`` whose color equals
    ``COLOR_NAMES[0]`` and maps its position to a room via
    ``env.room_from_pos``. Returns ``None`` when no such ball exists.
    """
    ball = next(
        (
            cell
            for cell in env.grid.grid
            if isinstance(cell, Ball) and cell.color == COLOR_NAMES[0]
        ),
        None,
    )
    if ball is None:
        return None
    return env.room_from_pos(*ball.cur_pos)
def find_target_key(env, color):
    """Report whether any box in the grid holds an object of ``color``.

    Scans every cell of ``env.grid.grid`` and returns ``True`` as soon as a
    ``Box`` whose ``contains`` has the requested color is found, otherwise
    ``False``. Empty boxes (``contains is None``) are skipped rather than
    raising ``AttributeError``.

    NOTE(review): presumably a box that was covered by another object during
    generation is no longer present in ``env.grid.grid``, so its key counts as
    missing here -- confirm against minigrid's ``Grid.set`` semantics.
    """
    return any(
        isinstance(obj, Box)
        and obj.contains is not None  # Box(contains=None) is a valid, empty box
        and obj.contains.color == color
        for obj in env.grid.grid
    )
def env_test(env_id, repeats=10000):
    """Estimate how often ``env_id`` generates an unsolvable layout.

    Resets the environment ``repeats`` times. A reset counts as unsolvable
    when none of the doors of the ball's room has a matching-color key held
    in a box anywhere on the grid.

    Args:
        env_id: Gymnasium environment id, including the version suffix.
        repeats: Number of resets to sample; must be positive.

    Returns:
        The percentage (0-100) of sampled resets that were unsolvable.

    Raises:
        ValueError: If ``repeats`` is not positive.
    """
    if repeats <= 0:
        raise ValueError(f"repeats must be positive, got {repeats!r}")

    env = gym.make(env_id)
    try:
        # gym.make returns a wrapped env; the grid/room helpers need the raw
        # MiniGrid env, and forwarding attribute access through wrappers is
        # deprecated (removed in Gymnasium 1.0).
        base_env = env.unwrapped
        unsolvable = 0
        for _ in range(repeats):
            env.reset()
            ball_room = find_ball_room(base_env)
            # room.doors has None entries for sides without a door.
            doors = [door for door in ball_room.doors if door]
            if not any(find_target_key(base_env, door.color) for door in doors):
                unsolvable += 1
    finally:
        env.close()

    return unsolvable / repeats * 100
@pytest.mark.parametrize("env_id", TESTING_ENVS)
def test_solvable_env(env_id):
    """Check that the ``-v1`` release of ``env_id`` never yields an unsolvable layout."""
    versioned_id = env_id + "-v1"
    unsolvable_pct = env_test(versioned_id)
    assert unsolvable_pct == 0, f"{env_id} is unsolvable."
def main():
    """Print how often each ObstructedMaze variant is unsolvable, per version.

    Covers MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q and -Full. In the v0
    releases of these environments, the box holding the key for the door that
    connects to the upper-right room may be covered by the ball blocking the
    door that connects to the lower-right room, which makes the layout
    unsolvable.

    Note: in MiniGrid-ObstructedMaze-Full such covering does not make the
    layout unsolvable.

    Expected rates of unsolvable layouts for v0:
        - MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
        - MiniGrid-ObstructedMaze-1Q-v0: 1 / 15 = 6.67%
        - MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
        - MiniGrid-ObstructedMaze-Full-v0: 0%

    The v1 releases are expected to report 0% for every variant. All v0
    results are printed before all v1 results.
    """
    for version in ("v0", "v1"):
        for env_id in TESTING_ENVS:
            rate = env_test(env_id + "-" + version)
            print(f"{env_id}: {rate:.2f}% unsolvable.")
if __name__ == "__main__":
    # Entry point when the file is executed as a script rather than
    # collected by pytest.
    main()
|