You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

391 lines
11 KiB

2 months ago
  1. from __future__ import annotations
  2. import math
  3. import gymnasium as gym
  4. import numpy as np
  5. import pytest
  6. from minigrid.core.actions import Actions
  7. from minigrid.core.constants import OBJECT_TO_IDX
  8. from minigrid.envs import EmptyEnv
  9. from minigrid.wrappers import (
  10. ActionBonus,
  11. DictObservationSpaceWrapper,
  12. DirectionObsWrapper,
  13. FlatObsWrapper,
  14. FullyObsWrapper,
  15. ImgObsWrapper,
  16. NoDeath,
  17. OneHotPartialObsWrapper,
  18. PositionBonus,
  19. ReseedWrapper,
  20. RGBImgObsWrapper,
  21. RGBImgPartialObsWrapper,
  22. StochasticActionWrapper,
  23. SymbolicObsWrapper,
  24. ViewSizeWrapper,
  25. )
  26. from tests.utils import all_testing_env_specs, assert_equals, minigrid_testing_env_specs
# Seeds handed to ReseedWrapper in test_reseed_wrapper; a reference env is
# reset explicitly with each of them in turn and compared step by step.
SEEDS = [100, 243, 500]
# Maximum number of steps taken per seeded episode in test_reseed_wrapper.
NUM_STEPS = 100
  29. @pytest.mark.parametrize(
  30. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  31. )
  32. def test_reseed_wrapper(env_spec):
  33. """
  34. Test the ReseedWrapper with a list of SEEDS.
  35. """
  36. unwrapped_env = env_spec.make()
  37. env = env_spec.make()
  38. env = ReseedWrapper(env, seeds=SEEDS)
  39. env.action_space.seed(0)
  40. for seed in SEEDS:
  41. env.reset()
  42. unwrapped_env.reset(seed=seed)
  43. for time_step in range(NUM_STEPS):
  44. action = env.action_space.sample()
  45. obs, rew, terminated, truncated, info = env.step(action)
  46. (
  47. unwrapped_obs,
  48. unwrapped_rew,
  49. unwrapped_terminated,
  50. unwrapped_truncated,
  51. unwrapped_info,
  52. ) = unwrapped_env.step(action)
  53. assert_equals(obs, unwrapped_obs, f"[{time_step}] ")
  54. assert unwrapped_env.observation_space.contains(obs)
  55. assert (
  56. rew == unwrapped_rew
  57. ), f"[{time_step}] reward={rew}, unwrapped reward={unwrapped_rew}"
  58. assert (
  59. terminated == unwrapped_terminated
  60. ), f"[{time_step}] terminated={terminated}, unwrapped terminated={unwrapped_terminated}"
  61. assert (
  62. truncated == unwrapped_truncated
  63. ), f"[{time_step}] truncated={truncated}, unwrapped truncated={unwrapped_truncated}"
  64. assert_equals(info, unwrapped_info, f"[{time_step}] ")
  65. # Start the next seed
  66. if terminated or truncated:
  67. break
  68. env.close()
  69. unwrapped_env.close()
  70. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  71. def test_position_bonus_wrapper(env_id):
  72. env = gym.make(env_id)
  73. wrapped_env = PositionBonus(gym.make(env_id))
  74. action_forward = Actions.forward
  75. action_left = Actions.left
  76. action_right = Actions.right
  77. for _ in range(10):
  78. wrapped_env.reset()
  79. for _ in range(5):
  80. wrapped_env.step(action_forward)
  81. # Turn lef 3 times (check that actions don't influence bonus)
  82. for _ in range(3):
  83. _, wrapped_rew, _, _, _ = wrapped_env.step(action_left)
  84. env.reset()
  85. for _ in range(5):
  86. env.step(action_forward)
  87. # Turn right 3 times
  88. for _ in range(3):
  89. _, rew, _, _, _ = env.step(action_right)
  90. expected_bonus_reward = rew + 1 / math.sqrt(13)
  91. assert expected_bonus_reward == wrapped_rew
  92. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  93. def test_action_bonus_wrapper(env_id):
  94. env = gym.make(env_id)
  95. wrapped_env = ActionBonus(gym.make(env_id))
  96. action = Actions.forward
  97. for _ in range(10):
  98. wrapped_env.reset()
  99. for _ in range(5):
  100. _, wrapped_rew, _, _, _ = wrapped_env.step(action)
  101. env.reset()
  102. for _ in range(5):
  103. _, rew, _, _, _ = env.step(action)
  104. expected_bonus_reward = rew + 1 / math.sqrt(10)
  105. assert expected_bonus_reward == wrapped_rew
  106. @pytest.mark.parametrize(
  107. "env_spec",
  108. minigrid_testing_env_specs,
  109. ids=[spec.id for spec in minigrid_testing_env_specs],
  110. ) # DictObservationSpaceWrapper is not compatible with BabyAI levels. See minigrid/wrappers.py for more details.
  111. def test_dict_observation_space_wrapper(env_spec):
  112. env = env_spec.make()
  113. env = DictObservationSpaceWrapper(env)
  114. env.reset()
  115. mission = env.mission
  116. obs, _, _, _, _ = env.step(0)
  117. assert env.string_to_indices(mission) == [
  118. value for value in obs["mission"] if value != 0
  119. ]
  120. env.close()
  121. @pytest.mark.parametrize(
  122. "wrapper",
  123. [
  124. ReseedWrapper,
  125. ImgObsWrapper,
  126. FlatObsWrapper,
  127. ViewSizeWrapper,
  128. DictObservationSpaceWrapper,
  129. OneHotPartialObsWrapper,
  130. RGBImgPartialObsWrapper,
  131. FullyObsWrapper,
  132. ],
  133. )
  134. @pytest.mark.parametrize(
  135. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  136. )
  137. def test_main_wrappers(wrapper, env_spec):
  138. if (
  139. wrapper in (DictObservationSpaceWrapper, FlatObsWrapper)
  140. and env_spec not in minigrid_testing_env_specs
  141. ):
  142. # DictObservationSpaceWrapper and FlatObsWrapper are not compatible with BabyAI levels
  143. # See minigrid/wrappers.py for more details
  144. pytest.skip()
  145. env = env_spec.make()
  146. env = wrapper(env)
  147. for _ in range(10):
  148. env.reset()
  149. env.step(0)
  150. env.close()
  151. @pytest.mark.parametrize(
  152. "wrapper",
  153. [
  154. OneHotPartialObsWrapper,
  155. RGBImgPartialObsWrapper,
  156. FullyObsWrapper,
  157. ],
  158. )
  159. @pytest.mark.parametrize(
  160. "env_spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
  161. )
  162. def test_observation_space_wrappers(wrapper, env_spec):
  163. env = wrapper(env_spec.make(disable_env_checker=True))
  164. obs_space, wrapper_name = env.observation_space, wrapper.__name__
  165. assert isinstance(
  166. obs_space, gym.spaces.Dict
  167. ), f"Observation space for {wrapper_name} is not a Dict: {obs_space}."
  168. # This should not fail either
  169. ImgObsWrapper(env)
  170. env.reset()
  171. env.step(0)
  172. env.close()
  173. class EmptyEnvWithExtraObs(EmptyEnv):
  174. """
  175. Custom environment with an extra observation
  176. """
  177. def __init__(self) -> None:
  178. super().__init__(size=5)
  179. self.observation_space["size"] = gym.spaces.Box(
  180. low=0, high=np.iinfo(np.uint).max, shape=(2,), dtype=np.uint
  181. )
  182. def reset(self, **kwargs):
  183. obs, info = super().reset(**kwargs)
  184. obs["size"] = np.array([self.width, self.height])
  185. return obs, info
  186. def step(self, action):
  187. obs, reward, terminated, truncated, info = super().step(action)
  188. obs["size"] = np.array([self.width, self.height])
  189. return obs, reward, terminated, truncated, info
  190. @pytest.mark.parametrize(
  191. "wrapper",
  192. [
  193. OneHotPartialObsWrapper,
  194. RGBImgObsWrapper,
  195. RGBImgPartialObsWrapper,
  196. FullyObsWrapper,
  197. ],
  198. )
  199. def test_agent_sees_method(wrapper):
  200. env1 = wrapper(EmptyEnvWithExtraObs())
  201. env2 = wrapper(gym.make("MiniGrid-Empty-5x5-v0"))
  202. obs1, _ = env1.reset(seed=0)
  203. obs2, _ = env2.reset(seed=0)
  204. assert "size" in obs1
  205. assert obs1["size"].shape == (2,)
  206. assert (obs1["size"] == [5, 5]).all()
  207. for key in obs2:
  208. assert np.array_equal(obs1[key], obs2[key])
  209. obs1, reward1, terminated1, truncated1, _ = env1.step(0)
  210. obs2, reward2, terminated2, truncated2, _ = env2.step(0)
  211. assert "size" in obs1
  212. assert obs1["size"].shape == (2,)
  213. assert (obs1["size"] == [5, 5]).all()
  214. for key in obs2:
  215. assert np.array_equal(obs1[key], obs2[key])
  216. @pytest.mark.parametrize("view_size", [5, 7, 9])
  217. def test_viewsize_wrapper(view_size):
  218. env = gym.make("MiniGrid-Empty-5x5-v0")
  219. env = ViewSizeWrapper(env, agent_view_size=view_size)
  220. env.reset()
  221. obs, _, _, _, _ = env.step(0)
  222. assert obs["image"].shape == (view_size, view_size, 3)
  223. env.close()
  224. @pytest.mark.parametrize("env_id", ["MiniGrid-LavaCrossingS11N5-v0"])
  225. @pytest.mark.parametrize("type", ["slope", "angle"])
  226. def test_direction_obs_wrapper(env_id, type):
  227. env = gym.make(env_id)
  228. env = DirectionObsWrapper(env, type=type)
  229. obs, _ = env.reset()
  230. slope = np.divide(
  231. env.goal_position[1] - env.agent_pos[1],
  232. env.goal_position[0] - env.agent_pos[0],
  233. )
  234. if type == "slope":
  235. assert obs["goal_direction"] == slope
  236. elif type == "angle":
  237. assert obs["goal_direction"] == np.arctan(slope)
  238. obs, _, _, _, _ = env.step(0)
  239. slope = np.divide(
  240. env.goal_position[1] - env.agent_pos[1],
  241. env.goal_position[0] - env.agent_pos[0],
  242. )
  243. if type == "slope":
  244. assert obs["goal_direction"] == slope
  245. elif type == "angle":
  246. assert obs["goal_direction"] == np.arctan(slope)
  247. env.close()
  248. @pytest.mark.parametrize("env_id", ["MiniGrid-DistShift1-v0"])
  249. def test_symbolic_obs_wrapper(env_id):
  250. env = gym.make(env_id)
  251. env = SymbolicObsWrapper(env)
  252. obs, _ = env.reset(seed=123)
  253. agent_pos = env.agent_pos
  254. goal_pos = env.goal_pos
  255. assert obs["image"].shape == (env.width, env.height, 3)
  256. assert np.alltrue(
  257. obs["image"][agent_pos[0], agent_pos[1], :]
  258. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  259. )
  260. assert np.alltrue(
  261. obs["image"][goal_pos[0], goal_pos[1], :]
  262. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  263. )
  264. obs, _, _, _, _ = env.step(2)
  265. agent_pos = env.agent_pos
  266. goal_pos = env.goal_pos
  267. assert obs["image"].shape == (env.width, env.height, 3)
  268. assert np.alltrue(
  269. obs["image"][agent_pos[0], agent_pos[1], :]
  270. == np.array([agent_pos[0], agent_pos[1], OBJECT_TO_IDX["agent"]])
  271. )
  272. assert np.alltrue(
  273. obs["image"][goal_pos[0], goal_pos[1], :]
  274. == np.array([goal_pos[0], goal_pos[1], OBJECT_TO_IDX["goal"]])
  275. )
  276. env.close()
  277. @pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
  278. def test_stochastic_action_wrapper(env_id):
  279. env = gym.make(env_id)
  280. env = StochasticActionWrapper(env, prob=0.2)
  281. _, _ = env.reset()
  282. for _ in range(20):
  283. _, _, _, _, _ = env.step(0)
  284. env.close()
  285. env = gym.make(env_id)
  286. env = StochasticActionWrapper(env, prob=0.2, random_action=1)
  287. _, _ = env.reset()
  288. for _ in range(20):
  289. _, _, _, _, _ = env.step(0)
  290. env.close()
  291. def test_dict_observation_space_doesnt_clash_with_one_hot():
  292. env = gym.make("MiniGrid-Empty-5x5-v0")
  293. env = OneHotPartialObsWrapper(env)
  294. env = DictObservationSpaceWrapper(env)
  295. env.reset()
  296. obs, _, _, _, _ = env.step(0)
  297. assert obs["image"].shape == (7, 7, 20)
  298. assert env.observation_space["image"].shape == (7, 7, 20)
  299. env.close()
  300. def test_no_death_wrapper():
  301. death_cost = -1
  302. env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
  303. _, _ = env.reset(seed=2)
  304. _, _, _, _, _ = env.step(1)
  305. _, reward, term, *_ = env.step(2)
  306. env_wrap = NoDeath(env, ("lava",), death_cost)
  307. _, _ = env_wrap.reset(seed=2)
  308. _, _, _, _, _ = env_wrap.step(1)
  309. _, reward_wrap, term_wrap, *_ = env_wrap.step(2)
  310. assert term and not term_wrap
  311. assert reward_wrap == reward + death_cost
  312. env.close()
  313. env_wrap.close()
  314. env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
  315. _, _ = env.reset(seed=2)
  316. _, reward, term, *_ = env.step(2)
  317. env = NoDeath(env, ("ball",), death_cost)
  318. _, _ = env.reset(seed=2)
  319. _, reward_wrap, term_wrap, *_ = env.step(2)
  320. assert term and not term_wrap
  321. assert reward_wrap == reward + death_cost
  322. env.close()
  323. env_wrap.close()