You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

75 lines
2.2 KiB

2 months ago
  1. from __future__ import annotations
  2. import gymnasium as gym
  3. import pytest
  4. from minigrid.core.constants import COLOR_NAMES
  5. from minigrid.core.world_object import Ball, Box
  6. TESTING_ENVS = [
  7. "MiniGrid-ObstructedMaze-2Dlhb",
  8. "MiniGrid-ObstructedMaze-1Q",
  9. "MiniGrid-ObstructedMaze-2Q",
  10. "MiniGrid-ObstructedMaze-Full",
  11. ]
  12. def find_ball_room(env):
  13. for obj in env.grid.grid:
  14. if isinstance(obj, Ball) and obj.color == COLOR_NAMES[0]:
  15. return env.room_from_pos(*obj.cur_pos)
  16. def find_target_key(env, color):
  17. for obj in env.grid.grid:
  18. if isinstance(obj, Box) and obj.contains.color == color:
  19. return True
  20. return False
  21. def env_test(env_id, repeats=10000):
  22. env = gym.make(env_id)
  23. cnt = 0
  24. for _ in range(repeats):
  25. env.reset()
  26. ball_room = find_ball_room(env)
  27. ball_room_doors = list(filter(None, ball_room.doors))
  28. keys_exit = [find_target_key(env, door.color) for door in ball_room_doors]
  29. if not any(keys_exit):
  30. cnt += 1
  31. return (cnt / repeats) * 100
  32. @pytest.mark.parametrize("env_id", TESTING_ENVS)
  33. def test_solvable_env(env_id):
  34. assert env_test(env_id + "-v1") == 0, f"{env_id} is unsolvable."
  35. def main():
  36. """
  37. Test the frequency of unsolvable situation in this environment, including
  38. MiniGrid-ObstructedMaze-2Dlhb, -1Q, -2Q, and -Full. The reason for the unsolvable
  39. situation is that in the v0 version of these environments, the box containing
  40. the key to the door connecting the upper-right room may be covered by the
  41. blocking ball of the door connecting the lower-right room.
  42. Note: Covering that occurs in MiniGrid-ObstructedMaze-Full won't lead to an
  43. unsolvable situation.
  44. Expected probability of unsolvable situation:
  45. - MiniGrid-ObstructedMaze-2Dlhb-v0: 1 / 15 = 6.67%
  46. - MiniGrid-ObstructedMaze-1Q-v0: 1/ 15 = 6.67%
  47. - MiniGrid-ObstructedMaze-2Q-v0: 1 / 30 = 3.33%
  48. - MiniGrid-ObstructedMaze-Full-v0: 0%
  49. """
  50. for env_id in TESTING_ENVS:
  51. print(f"{env_id}: {env_test(env_id + '-v0'):.2f}% unsolvable.")
  52. for env_id in TESTING_ENVS:
  53. print(f"{env_id}: {env_test(env_id + '-v1'):.2f}% unsolvable.")
  54. if __name__ == "__main__":
  55. main()