You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

132 lines
3.3 KiB

#!/usr/bin/env python3
from __future__ import annotations
import time
import gymnasium as gym
from minigrid.manual_control import ManualControl
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
def benchmark(env_id, num_resets, num_frames):
    """Benchmark reset latency and rendering throughput of a Gymnasium env.

    Prints three measurements:

    * the mean wall-clock time of ``env.reset()`` over ``num_resets`` calls,
      in milliseconds;
    * full-environment ``env.render()`` calls per second over ``num_frames``
      calls (``render_mode="rgb_array"``);
    * ``env.step(0)`` calls per second over ``num_frames`` calls on an env
      whose observation is the agent's partial view rendered as an RGB image
      (``RGBImgPartialObsWrapper`` + ``ImgObsWrapper``).

    Args:
        env_id: Gymnasium id of the environment to create with ``gym.make``.
        num_resets: Number of ``reset()`` calls to time; must be positive.
        num_frames: Number of ``render()`` / ``step()`` calls to time; must be
            positive.

    Raises:
        ValueError: If ``num_resets`` or ``num_frames`` is not positive.
    """
    # Validate before creating any env: a zero count would otherwise divide by
    # zero, and a negative one would silently produce meaningless numbers.
    if num_resets < 1 or num_frames < 1:
        raise ValueError(
            "num_resets and num_frames must be positive, "
            f"got num_resets={num_resets}, num_frames={num_frames}"
        )

    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which can jump with system clock adjustments.
    env = gym.make(env_id, render_mode="rgb_array")
    try:
        # Benchmark env.reset
        start = time.perf_counter()
        for _ in range(num_resets):
            env.reset()
        reset_time = 1000 * (time.perf_counter() - start) / num_resets

        # Benchmark rendering (the env has been reset at least once above)
        start = time.perf_counter()
        for _ in range(num_frames):
            env.render()
        frames_per_sec = num_frames / (time.perf_counter() - start)
    finally:
        # Close this env before building the next one instead of leaking it.
        env.close()

    # Create an environment with an RGB agent observation
    env = ImgObsWrapper(
        RGBImgPartialObsWrapper(gym.make(env_id, render_mode="rgb_array"))
    )
    try:
        env.reset()
        # Benchmark rendering in agent view. Action 0 is presumably MiniGrid's
        # "turn left" — confirm against the env's action space.
        start = time.perf_counter()
        for _ in range(num_frames):
            _, _, terminated, truncated, _ = env.step(0)
            # Gymnasium requires reset() once an episode has ended; stepping a
            # finished episode is undefined behaviour.
            if terminated or truncated:
                env.reset()
        agent_view_fps = num_frames / (time.perf_counter() - start)
    finally:
        env.close()

    print(f"Env reset time: {reset_time:.1f} ms")
    print(f"Rendering FPS : {frames_per_sec:.0f}")
    print(f"Agent view FPS: {agent_view_fps:.0f}")
def benchmark_manual_control(env_id, num_resets, num_frames, tile_size):
env = gym.make(env_id, tile_size=tile_size)
env = ManualControl(env, seed=args.seed)
# Benchmark env.reset
t0 = time.time()
for i in range(num_resets):
env.reset()
t1 = time.time()
dt = t1 - t0
reset_time = (1000 * dt) / num_resets
# Benchmark rendering
t0 = time.time()
for i in range(num_frames):
env.redraw()
t1 = time.time()
dt = t1 - t0
frames_per_sec = num_frames / dt
# Create an environment with an RGB agent observation
env = gym.make(env_id, tile_size=tile_size)
env = RGBImgPartialObsWrapper(env, env.tile_size)
env = ImgObsWrapper(env)
env = ManualControl(env, seed=args.seed)
env.reset()
# Benchmark rendering in agent view
t0 = time.time()
for i in range(num_frames):
env.step(0)
t1 = time.time()
dt = t1 - t0
agent_view_fps = num_frames / dt
print(f"Env reset time: {reset_time:.1f} ms")
print(f"Rendering FPS : {frames_per_sec:.0f}")
print(f"Agent view FPS: {agent_view_fps:.0f}")
env.close()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--env-id",
dest="env_id",
help="gym environment to load",
default="MiniGrid-LavaGapS7-v0",
)
parser.add_argument(
"--seed",
type=int,
help="random seed to generate the environment with",
default=None,
)
parser.add_argument(
"--num-resets",
type=int,
help="number of times to reset the environment for benchmarking",
default=200,
)
parser.add_argument(
"--num-frames",
type=int,
help="number of frames to test rendering for",
default=5000,
)
parser.add_argument(
"--tile-size", type=int, help="size at which to render tiles", default=32
)
args = parser.parse_args()
benchmark(args.env_id, args.num_resets, args.num_frames)
benchmark_manual_control(
args.env_id, args.num_resets, args.num_frames, args.tile_size
)