The source code and Dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

351 lines
11 KiB

4 months ago
  1. from __future__ import annotations
  2. import pickle
  3. import re
  4. import warnings
  5. import gymnasium as gym
  6. import numpy as np
  7. import pytest
  8. from gymnasium.envs.registration import EnvSpec
  9. from gymnasium.utils.env_checker import check_env, data_equivalence
  10. from minigrid.core.grid import Grid
  11. from minigrid.core.mission import MissionSpace
  12. from tests.utils import all_testing_env_specs, assert_equals
  13. CHECK_ENV_IGNORE_WARNINGS = [
  14. f"\x1b[33mWARN: {message}\x1b[0m"
  15. for message in [
  16. "A Box observation space minimum value is -infinity. This is probably too low.",
  17. "A Box observation space maximum value is -infinity. This is probably too high.",
  18. "For Box action spaces, we recommend using a symmetric and normalized space (range=[-1, 1] or [0, 1]). See https://stable-baselines3.readthedocs.io/en/master/guide/rl_tips.html for more information.",
  19. ]
  20. ]
  21. @pytest.mark.parametrize(
  22. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  23. )
  24. def test_env(spec):
  25. # Capture warnings
  26. env = spec.make(disable_env_checker=True).unwrapped
  27. warnings.simplefilter("always")
  28. # Test if env adheres to Gym API
  29. with warnings.catch_warnings(record=True) as w:
  30. check_env(env)
  31. for warning in w:
  32. if warning.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
  33. raise gym.error.Error(f"Unexpected warning: {warning.message}")
  34. # Note that this precludes running this test in multiple threads.
  35. # However, we probably already can't do multithreading due to some environments.
  36. SEED = 0
  37. NUM_STEPS = 50
  38. @pytest.mark.parametrize(
  39. "env_spec", all_testing_env_specs, ids=[env.id for env in all_testing_env_specs]
  40. )
  41. def test_env_determinism_rollout(env_spec: EnvSpec):
  42. """Run a rollout with two environments and assert equality.
  43. This test run a rollout of NUM_STEPS steps with two environments
  44. initialized with the same seed and assert that:
  45. - observation after first reset are the same
  46. - same actions are sampled by the two envs
  47. - observations are contained in the observation space
  48. - obs, rew, terminated, truncated and info are equals between the two envs
  49. """
  50. # Don't check rollout equality if it's a nondeterministic environment.
  51. if env_spec.nondeterministic is True:
  52. return
  53. env_1 = env_spec.make(disable_env_checker=True)
  54. env_2 = env_spec.make(disable_env_checker=True)
  55. initial_obs_1 = env_1.reset(seed=SEED)
  56. initial_obs_2 = env_2.reset(seed=SEED)
  57. assert_equals(initial_obs_1, initial_obs_2)
  58. env_1.action_space.seed(SEED)
  59. for time_step in range(NUM_STEPS):
  60. # We don't evaluate the determinism of actions
  61. action = env_1.action_space.sample()
  62. obs_1, rew_1, terminated_1, truncated_1, info_1 = env_1.step(action)
  63. obs_2, rew_2, terminated_2, truncated_2, info_2 = env_2.step(action)
  64. assert_equals(obs_1, obs_2, f"[{time_step}] ")
  65. assert env_1.observation_space.contains(
  66. obs_1
  67. ) # obs_2 verified by previous assertion
  68. assert rew_1 == rew_2, f"[{time_step}] reward 1={rew_1}, reward 2={rew_2}"
  69. assert (
  70. terminated_1 == terminated_2
  71. ), f"[{time_step}] terminated 1={terminated_1}, terminated 2={terminated_2}"
  72. assert (
  73. truncated_1 == truncated_2
  74. ), f"[{time_step}] truncated 1={truncated_1}, truncated 2={truncated_2}"
  75. assert_equals(info_1, info_2, f"[{time_step}] ")
  76. if (
  77. terminated_1 or truncated_1
  78. ): # terminated_2 and truncated_2 verified by previous assertion
  79. env_1.reset(seed=SEED)
  80. env_2.reset(seed=SEED)
  81. env_1.close()
  82. env_2.close()
  83. @pytest.mark.parametrize(
  84. "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  85. )
  86. def test_render_modes(spec):
  87. env = spec.make()
  88. for mode in env.metadata.get("render_modes", []):
  89. if mode != "human":
  90. new_env = spec.make(render_mode=mode)
  91. new_env.reset()
  92. new_env.step(new_env.action_space.sample())
  93. new_env.render()
  94. @pytest.mark.parametrize("env_id", ["MiniGrid-DoorKey-6x6-v0"])
  95. def test_agent_sees_method(env_id):
  96. env = gym.make(env_id)
  97. goal_pos = (env.grid.width - 2, env.grid.height - 2)
  98. # Test the env.agent_sees() function
  99. env.reset()
  100. # Test the "in" operator on grid objects
  101. assert ("green", "goal") in env.grid
  102. assert ("blue", "key") not in env.grid
  103. for i in range(0, 500):
  104. action = env.action_space.sample()
  105. obs, reward, terminated, truncated, info = env.step(action)
  106. grid, _ = Grid.decode(obs["image"])
  107. goal_visible = ("green", "goal") in grid
  108. agent_sees_goal = env.agent_sees(*goal_pos)
  109. assert agent_sees_goal == goal_visible
  110. if terminated or truncated:
  111. env.reset()
  112. env.close()
  113. @pytest.mark.parametrize(
  114. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  115. )
  116. def test_max_steps_argument(env_spec):
  117. """
  118. Test that when initializing an environment with a fixed number of steps per episode (`max_steps` argument),
  119. the episode will be truncated after taking that number of steps.
  120. """
  121. max_steps = 50
  122. env = env_spec.make(max_steps=max_steps)
  123. env.reset()
  124. step_count = 0
  125. while True:
  126. _, _, terminated, truncated, _ = env.step(4)
  127. step_count += 1
  128. if truncated:
  129. assert step_count == max_steps
  130. step_count = 0
  131. break
  132. env.close()
  133. @pytest.mark.parametrize(
  134. "env_spec",
  135. all_testing_env_specs,
  136. ids=[spec.id for spec in all_testing_env_specs],
  137. )
  138. def test_pickle_env(env_spec):
  139. """Test that all environments are picklable."""
  140. env: gym.Env = env_spec.make()
  141. pickled_env: gym.Env = pickle.loads(pickle.dumps(env))
  142. data_equivalence(env.reset(), pickled_env.reset())
  143. action = env.action_space.sample()
  144. data_equivalence(env.step(action), pickled_env.step(action))
  145. env.close()
  146. pickled_env.close()
  147. @pytest.mark.parametrize(
  148. "env_spec",
  149. all_testing_env_specs,
  150. ids=[spec.id for spec in all_testing_env_specs],
  151. )
  152. def old_run_test(env_spec):
  153. # Load the gym environment
  154. env = env_spec.make()
  155. env.max_steps = min(env.max_steps, 200)
  156. env.reset()
  157. env.render()
  158. # Verify that the same seed always produces the same environment
  159. for i in range(0, 5):
  160. seed = 1337 + i
  161. _ = env.reset(seed=seed)
  162. grid1 = env.grid
  163. _ = env.reset(seed=seed)
  164. grid2 = env.grid
  165. assert grid1 == grid2
  166. env.reset()
  167. # Run for a few episodes
  168. num_episodes = 0
  169. while num_episodes < 5:
  170. # Pick a random action
  171. action = env.action_space.sample()
  172. obs, reward, terminated, truncated, info = env.step(action)
  173. # Validate the agent position
  174. assert env.agent_pos[0] < env.width
  175. assert env.agent_pos[1] < env.height
  176. # Test observation encode/decode roundtrip
  177. img = obs["image"]
  178. grid, vis_mask = Grid.decode(img)
  179. img2 = grid.encode(vis_mask=vis_mask)
  180. assert np.array_equal(img, img2)
  181. # Test the env to string function
  182. str(env)
  183. # Check that the reward is within the specified range
  184. assert reward >= env.reward_range[0], reward
  185. assert reward <= env.reward_range[1], reward
  186. if terminated or truncated:
  187. num_episodes += 1
  188. env.reset()
  189. env.render()
  190. # Test the close method
  191. env.close()
  192. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-8x8-v0"])
  193. def test_interactive_mode(env_id):
  194. env = gym.make(env_id)
  195. env.reset()
  196. for i in range(0, 100):
  197. print(f"step {i}")
  198. # Pick a random action
  199. action = env.action_space.sample()
  200. obs, reward, terminated, truncated, info = env.step(action)
  201. # Test the close method
  202. env.close()
  203. def test_mission_space():
  204. # Test placeholders
  205. mission_space = MissionSpace(
  206. mission_func=lambda color, obj_type: f"Get the {color} {obj_type}.",
  207. ordered_placeholders=[["green", "red"], ["ball", "key"]],
  208. )
  209. assert mission_space.contains("Get the green ball.")
  210. assert mission_space.contains("Get the red key.")
  211. assert not mission_space.contains("Get the purple box.")
  212. # Test passing inverted placeholders
  213. assert not mission_space.contains("Get the key red.")
  214. # Test passing extra repeated placeholders
  215. assert not mission_space.contains("Get the key red key.")
  216. # Test contained placeholders like "get the" and "go get the". "get the" string is contained in both placeholders.
  217. mission_space = MissionSpace(
  218. mission_func=lambda get_syntax, obj_type: f"{get_syntax} {obj_type}.",
  219. ordered_placeholders=[
  220. ["go get the", "get the", "go fetch the", "fetch the"],
  221. ["ball", "key"],
  222. ],
  223. )
  224. assert mission_space.contains("get the ball.")
  225. assert mission_space.contains("go get the key.")
  226. assert mission_space.contains("go fetch the ball.")
  227. # Test repeated placeholders
  228. mission_space = MissionSpace(
  229. mission_func=lambda get_syntax, color_1, obj_type_1, color_2, obj_type_2: f"{get_syntax} {color_1} {obj_type_1} and the {color_2} {obj_type_2}.",
  230. ordered_placeholders=[
  231. ["go get the", "get the", "go fetch the", "fetch the"],
  232. ["green", "red"],
  233. ["ball", "key"],
  234. ["green", "red"],
  235. ["ball", "key"],
  236. ],
  237. )
  238. assert mission_space.contains("get the green key and the green key.")
  239. assert mission_space.contains("go fetch the red ball and the green key.")
  240. # not reasonable to test for all environments, test for a few of them.
  241. @pytest.mark.parametrize(
  242. "env_id",
  243. [
  244. "MiniGrid-Empty-8x8-v0",
  245. "MiniGrid-DoorKey-16x16-v0",
  246. "MiniGrid-ObstructedMaze-1Dl-v0",
  247. ],
  248. )
  249. def test_env_sync_vectorization(env_id):
  250. def env_maker(env_id, **kwargs):
  251. def env_func():
  252. env = gym.make(env_id, **kwargs)
  253. return env
  254. return env_func
  255. num_envs = 4
  256. env = gym.vector.SyncVectorEnv([env_maker(env_id) for _ in range(num_envs)])
  257. env.reset()
  258. env.step(env.action_space.sample())
  259. env.close()
  260. def test_pprint_grid(env_id="MiniGrid-Empty-8x8-v0"):
  261. env = gym.make(env_id)
  262. env_repr = str(env)
  263. assert (
  264. env_repr
  265. == "<OrderEnforcing<PassiveEnvChecker<EmptyEnv<MiniGrid-Empty-8x8-v0>>>>"
  266. )
  267. with pytest.raises(
  268. ValueError,
  269. match=re.escape(
  270. "The environment hasn't been `reset` therefore the `agent_pos`, `agent_dir` or `grid` are unknown."
  271. ),
  272. ):
  273. env.unwrapped.pprint_grid()
  274. env.reset()
  275. assert isinstance(env.unwrapped.pprint_grid(), str)