You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

45 lines
1.3 KiB

2 months ago
  1. """Finds all the specs that we can test with"""
  2. from __future__ import annotations
  3. import gymnasium as gym
  4. import numpy as np
  5. all_testing_env_specs = [
  6. env_spec
  7. for env_spec in gym.envs.registry.values()
  8. if (
  9. isinstance(env_spec.entry_point, str)
  10. and env_spec.entry_point.startswith("minigrid.envs")
  11. )
  12. ]
  13. minigrid_testing_env_specs = [
  14. env_spec
  15. for env_spec in all_testing_env_specs
  16. if not env_spec.entry_point.startswith("minigrid.envs.babyai")
  17. ]
  18. def assert_equals(a, b, prefix=None):
  19. """Assert equality of data structures `a` and `b`.
  20. Args:
  21. a: first data structure
  22. b: second data structure
  23. prefix: prefix for failed assertion message for types and dicts
  24. """
  25. assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
  26. if isinstance(a, dict):
  27. assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"
  28. for k in a.keys():
  29. v_a = a[k]
  30. v_b = b[k]
  31. assert_equals(v_a, v_b)
  32. elif isinstance(a, np.ndarray):
  33. np.testing.assert_array_equal(a, b)
  34. elif isinstance(a, tuple):
  35. for elem_from_a, elem_from_b in zip(a, b):
  36. assert_equals(elem_from_a, elem_from_b)
  37. else:
  38. assert a == b