#!/usr/bin/env python3

from __future__ import annotations

import time

import gymnasium as gym

from minigrid.manual_control import ManualControl
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper


def benchmark(env_id, num_resets, num_frames):
    """Benchmark reset, rendering and agent-view stepping for an environment.

    Three measurements are taken and printed to stdout:

    * average ``env.reset()`` time in milliseconds,
    * ``env.render()`` calls per second on an ``rgb_array`` environment,
    * ``env.step(0)`` calls per second on an environment wrapped with
      ``RGBImgPartialObsWrapper`` and ``ImgObsWrapper``.

    Args:
        env_id: Environment id passed to ``gym.make``.
        num_resets: Number of ``reset()`` calls to time. Must be non-zero,
            since the average is computed by dividing by it.
        num_frames: Number of ``render()`` / ``step()`` calls to time.
    """
    env = gym.make(env_id, render_mode="rgb_array")
    try:
        # Benchmark env.reset
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_resets):
            env.reset()
        reset_time = (1000 * (time.perf_counter() - start)) / num_resets

        # Benchmark rendering
        start = time.perf_counter()
        for _ in range(num_frames):
            env.render()
        frames_per_sec = num_frames / (time.perf_counter() - start)
    finally:
        # Release the first environment before a second one is created below;
        # otherwise it would be rebound and never closed.
        env.close()

    # Create an environment with an RGB agent observation
    env = gym.make(env_id, render_mode="rgb_array")
    env = RGBImgPartialObsWrapper(env)
    env = ImgObsWrapper(env)
    try:
        env.reset()

        # Benchmark rendering in agent view
        start = time.perf_counter()
        for _ in range(num_frames):
            env.step(0)
        agent_view_fps = num_frames / (time.perf_counter() - start)
    finally:
        env.close()

    print(f"Env reset time: {reset_time:.1f} ms")
    print(f"Rendering FPS : {frames_per_sec:.0f}")
    print(f"Agent view FPS: {agent_view_fps:.0f}")


def benchmark_manual_control(env_id, num_resets, num_frames, tile_size, seed=None):
    """Benchmark reset, redraw and stepping through ``ManualControl``.

    Mirrors ``benchmark`` but drives the environment through a
    ``ManualControl`` instance. Three figures are printed to stdout: average
    reset time in milliseconds, ``redraw()`` calls per second, and
    ``step(0)`` calls per second on an environment wrapped with
    ``RGBImgPartialObsWrapper`` and ``ImgObsWrapper``.

    Args:
        env_id: Environment id passed to ``gym.make``.
        num_resets: Number of ``reset()`` calls to time. Must be non-zero,
            since the average is computed by dividing by it.
        num_frames: Number of ``redraw()`` / ``step()`` calls to time.
        tile_size: Forwarded to ``gym.make`` as ``tile_size``.
        seed: Forwarded to ``ManualControl(seed=...)``. Taken as a parameter
            rather than read from the module-level ``args`` so the function
            also works when this module is imported instead of run as a
            script.
    """
    # NOTE(review): this relies on ManualControl exposing redraw() and
    # close() -- confirm against the installed minigrid version.
    base_env = gym.make(env_id, tile_size=tile_size)
    try:
        control = ManualControl(base_env, seed=seed)

        # Benchmark env.reset
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_resets):
            control.reset()
        reset_time = (1000 * (time.perf_counter() - start)) / num_resets

        # Benchmark rendering
        start = time.perf_counter()
        for _ in range(num_frames):
            control.redraw()
        frames_per_sec = num_frames / (time.perf_counter() - start)
    finally:
        # Release the first environment before a second one is created below;
        # otherwise it would be rebound and never closed.
        base_env.close()

    # Create an environment with an RGB agent observation
    env = gym.make(env_id, tile_size=tile_size)
    env = RGBImgPartialObsWrapper(env, env.tile_size)
    env = ImgObsWrapper(env)

    env = ManualControl(env, seed=seed)
    env.reset()

    # Benchmark rendering in agent view
    start = time.perf_counter()
    for _ in range(num_frames):
        env.step(0)
    agent_view_fps = num_frames / (time.perf_counter() - start)

    print(f"Env reset time: {reset_time:.1f} ms")
    print(f"Rendering FPS : {frames_per_sec:.0f}")
    print(f"Agent view FPS: {agent_view_fps:.0f}")

    env.close()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--env-id",
        dest="env_id",
        help="gym environment to load",
        default="MiniGrid-LavaGapS7-v0",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="random seed to generate the environment with",
        default=None,
    )
    parser.add_argument(
        "--num-resets",
        type=int,
        help="number of times to reset the environment for benchmarking",
        default=200,
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        help="number of frames to test rendering for",
        default=5000,
    )
    parser.add_argument(
        "--tile-size", type=int, help="size at which to render tiles", default=32
    )

    args = parser.parse_args()
    benchmark(args.env_id, args.num_resets, args.num_frames)

    # Pass the seed explicitly instead of having the function read ``args``.
    benchmark_manual_control(
        args.env_id,
        args.num_resets,
        args.num_frames,
        args.tile_size,
        seed=args.seed,
    )