You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

132 lines
3.3 KiB

2 months ago
  1. #!/usr/bin/env python3
  2. from __future__ import annotations
  3. import time
  4. import gymnasium as gym
  5. from minigrid.manual_control import ManualControl
  6. from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
  7. def benchmark(env_id, num_resets, num_frames):
  8. env = gym.make(env_id, render_mode="rgb_array")
  9. # Benchmark env.reset
  10. t0 = time.time()
  11. for i in range(num_resets):
  12. env.reset()
  13. t1 = time.time()
  14. dt = t1 - t0
  15. reset_time = (1000 * dt) / num_resets
  16. # Benchmark rendering
  17. t0 = time.time()
  18. for i in range(num_frames):
  19. env.render()
  20. t1 = time.time()
  21. dt = t1 - t0
  22. frames_per_sec = num_frames / dt
  23. # Create an environment with an RGB agent observation
  24. env = gym.make(env_id, render_mode="rgb_array")
  25. env = RGBImgPartialObsWrapper(env)
  26. env = ImgObsWrapper(env)
  27. env.reset()
  28. # Benchmark rendering in agent view
  29. t0 = time.time()
  30. for i in range(num_frames):
  31. obs, reward, terminated, truncated, info = env.step(0)
  32. t1 = time.time()
  33. dt = t1 - t0
  34. agent_view_fps = num_frames / dt
  35. print(f"Env reset time: {reset_time:.1f} ms")
  36. print(f"Rendering FPS : {frames_per_sec:.0f}")
  37. print(f"Agent view FPS: {agent_view_fps:.0f}")
  38. env.close()
  39. def benchmark_manual_control(env_id, num_resets, num_frames, tile_size):
  40. env = gym.make(env_id, tile_size=tile_size)
  41. env = ManualControl(env, seed=args.seed)
  42. # Benchmark env.reset
  43. t0 = time.time()
  44. for i in range(num_resets):
  45. env.reset()
  46. t1 = time.time()
  47. dt = t1 - t0
  48. reset_time = (1000 * dt) / num_resets
  49. # Benchmark rendering
  50. t0 = time.time()
  51. for i in range(num_frames):
  52. env.redraw()
  53. t1 = time.time()
  54. dt = t1 - t0
  55. frames_per_sec = num_frames / dt
  56. # Create an environment with an RGB agent observation
  57. env = gym.make(env_id, tile_size=tile_size)
  58. env = RGBImgPartialObsWrapper(env, env.tile_size)
  59. env = ImgObsWrapper(env)
  60. env = ManualControl(env, seed=args.seed)
  61. env.reset()
  62. # Benchmark rendering in agent view
  63. t0 = time.time()
  64. for i in range(num_frames):
  65. env.step(0)
  66. t1 = time.time()
  67. dt = t1 - t0
  68. agent_view_fps = num_frames / dt
  69. print(f"Env reset time: {reset_time:.1f} ms")
  70. print(f"Rendering FPS : {frames_per_sec:.0f}")
  71. print(f"Agent view FPS: {agent_view_fps:.0f}")
  72. env.close()
  73. if __name__ == "__main__":
  74. import argparse
  75. parser = argparse.ArgumentParser()
  76. parser.add_argument(
  77. "--env-id",
  78. dest="env_id",
  79. help="gym environment to load",
  80. default="MiniGrid-LavaGapS7-v0",
  81. )
  82. parser.add_argument(
  83. "--seed",
  84. type=int,
  85. help="random seed to generate the environment with",
  86. default=None,
  87. )
  88. parser.add_argument(
  89. "--num-resets",
  90. type=int,
  91. help="number of times to reset the environment for benchmarking",
  92. default=200,
  93. )
  94. parser.add_argument(
  95. "--num-frames",
  96. type=int,
  97. help="number of frames to test rendering for",
  98. default=5000,
  99. )
  100. parser.add_argument(
  101. "--tile-size", type=int, help="size at which to render tiles", default=32
  102. )
  103. args = parser.parse_args()
  104. benchmark(args.env_id, args.num_resets, args.num_frames)
  105. benchmark_manual_control(
  106. args.env_id, args.num_resets, args.num_frames, args.tile_size
  107. )