import time
from collections import deque
from typing import Dict, Tuple

import gymnasium as gym
import numpy as np
import torch
from torch import Tensor

from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log

from sf_examples.atari.train_atari import parse_atari_args, register_atari_components

class SampleFactoryNNQueryWrapper:
    """Load a trained Sample Factory Atari policy and query it one observation at a time.

    The actor-critic is built from the Atari experiment config, switched to eval
    mode, moved to the requested device and restored from the newest checkpoint
    matching ``<name_prefix>_*`` for the given policy. Actions are chosen greedily
    (argmax of the action distribution). The recurrent state is carried across
    ``query`` calls; call ``reset`` when a new episode starts.
    """

    def __init__(
        self,
        obs_space=None,
        action_space=None,
        device: str = "cpu",
        name_prefix: str = "best",
        policy_id: int = 0,
    ):
        """Build the policy and load its weights (see ``setup`` for the arguments)."""
        self.setup(
            obs_space=obs_space,
            action_space=action_space,
            device=device,
            name_prefix=name_prefix,
            policy_id=policy_id,
        )

    def setup(
        self,
        obs_space=None,
        action_space=None,
        device: str = "cpu",
        name_prefix: str = "best",
        policy_id: int = 0,
    ):
        """Create the actor-critic, load its checkpoint and initialise the recurrent state.

        Args:
            obs_space: Observation space used to build the model. Defaults to
                ``Dict({"obs": Box(0, 255, (4, 84, 84), uint8)})``.
            action_space: Action space used to build the model. Defaults to
                ``Discrete(3)``.
            device: Torch device string for the model and the recurrent state.
            name_prefix: Checkpoint file prefix: ``"best"`` for the best
                checkpoint or ``"checkpoint"`` for the latest one.
            policy_id: Index of the policy whose checkpoint directory is searched.

        Raises:
            FileNotFoundError: If no checkpoint matching ``<name_prefix>_*`` exists.
        """
        register_atari_components()
        # NOTE(review): the experiment config presumably comes from the command line.
        cfg = parse_atari_args()

        if obs_space is None:
            # TODO: derive from the environment instead of hard-coding the Atari layout.
            obs_space = gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)})
        if action_space is None:
            action_space = gym.spaces.Discrete(3)

        actor_critic = create_actor_critic(cfg, obs_space, action_space)
        actor_critic.eval()

        torch_device = torch.device(device)
        actor_critic.model_to_device(torch_device)

        checkpoint_dir = Learner.checkpoint_dir(cfg, policy_id)
        checkpoints = Learner.get_checkpoints(checkpoint_dir, f"{name_prefix}_*")
        if not checkpoints:
            # Fail early with an actionable message instead of an opaque error
            # when indexing the loaded checkpoint below.
            raise FileNotFoundError(
                f"No '{name_prefix}_*' checkpoints found for policy {policy_id} in {checkpoint_dir}"
            )
        checkpoint_dict = Learner.load_checkpoint(checkpoints, torch_device)
        actor_critic.load_state_dict(checkpoint_dict["model"])

        # Leading dimension 1: the state is sized for a single observation per query.
        self.rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=torch_device)
        self.actor_critic = actor_critic

    def reset(self):
        """Zero the recurrent state, e.g. at the start of a new episode."""
        self.rnn_states = torch.zeros_like(self.rnn_states)

    def query(self, obs):
        """Return the greedy action for ``obs`` and advance the recurrent state.

        Args:
            obs: Observation handed to ``prepare_and_normalize_obs``; presumably a
                dict keyed like the model's observation space (``"obs"`` by default)
                with a leading batch dimension of 1 — TODO confirm against callers.

        Returns:
            The first component of the first batch entry's action as a Python
            scalar (an integer action index for the default ``Discrete(3)`` space).
        """
        with torch.no_grad():
            normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
            policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)

            # Act greedily: use the argmax of the action distribution rather than the
            # sampled policy_outputs["actions"]. action_distribution() is read after
            # the forward pass above, which presumably populates it.
            actions = argmax_actions(self.actor_critic.action_distribution())
            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)

            # Persist the updated recurrent state so the next query continues from it
            # instead of restarting from the initial state every step.
            self.rnn_states = policy_outputs["new_rnn_states"]

        return actions[0][0].item()