import gymnasium as gym
import numpy as np
import random
from moviepy.editor import ImageSequenceClip

from utils import MiniGridShieldHandler, common_parser
from stable_baselines3.common.callbacks import BaseCallback, CheckpointCallback
from stable_baselines3.common.logger import Image

class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    def __init__(self,
                 env,
                 shield_handler : MiniGridShieldHandler,
                 create_shield_at_reset = False,
                 ):
        super().__init__(env)
        self.shield_handler = shield_handler
        self.create_shield_at_reset = create_shield_at_reset

        shield = self.shield_handler.create_shield(env=self.env)
        self.shield = shield

    def create_action_mask(self):
        try:
            return self.shield[self.env.get_symbolic_state()]
        except:
            return [0.0] * 3 + [0.0] * 4

    def reset(self, *, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)

        if self.create_shield_at_reset:
            shield = self.shield_handler.create_shield(env=self.env)
            self.shield = shield
        return obs, infos

    def step(self, action):
        obs, rew, done, truncated, info = self.env.step(action)
        info["no_shield_action"] = not self.shield.__contains__(self.env.get_symbolic_state())
        return obs, rew, done, truncated, info

def parse_sb3_arguments():
    parser = common_parser()
    args = parser.parse_args()

    return args

class ImageRecorderCallback(BaseCallback):
    def __init__(self, eval_env, render_freq, n_eval_episodes, evaluation_method, log_dir, deterministic=True, verbose=0):
        super().__init__(verbose)

        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic
        self._evaluation_method = evaluation_method
        self._log_dir = log_dir

    def _on_training_start(self):
        image = self.training_env.render(mode="rgb_array")
        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))

    def _on_step(self) -> bool:
        #if self.n_calls % self._render_freq == 0:
        #    self.record_video()
        return True

    def _on_training_end(self) -> None:
        self.record_video()

    def record_video(self) -> bool:
        screens = []
        def grab_screens(_locals, _globals) -> None:
            """
            Renders the environment in its current state, recording the screen in the captured `screens` list

            :param _locals: A dictionary containing all local variables of the callback's scope
            :param _globals: A dictionary containing all global variables of the callback's scope
            """
            screen = self._eval_env.render()
            screens.append(screen)
        self._evaluation_method(
            self.model,
            self._eval_env,
            callback=grab_screens,
            n_eval_episodes=self._n_eval_episodes,
            deterministic=self._deterministic,
        )

        clip = ImageSequenceClip(list(screens), fps=3)
        clip.write_gif(f"{self._log_dir}/{self.n_calls}.gif", fps=3)
        return True


class InfoCallback(BaseCallback):
    """
    Custom callback for plotting additional values in tensorboard.
    """

    def __init__(self, verbose=0):
        super().__init__(verbose)
        self.sum_goal = 0
        self.sum_lava = 0
        self.sum_collisions = 0
        self.sum_opened_door = 0
        self.sum_picked_up = 0
        self.no_shield_action = 0

    def _on_step(self) -> bool:
        infos = self.locals["infos"][0]
        if infos["reached_goal"]:
            self.sum_goal += 1
        if infos["ran_into_lava"]:
            self.sum_lava += 1
        self.logger.record("info/sum_reached_goal", self.sum_goal)
        self.logger.record("info/sum_ran_into_lava", self.sum_lava)
        if "collision" in infos:
            if infos["collision"]:
                self.sum_collisions += 1
            self.logger.record("info/sum_collision", self.sum_collisions)
        if "opened_door" in infos:
            if infos["opened_door"]:
                self.sum_opened_door += 1
            self.logger.record("info/sum_opened_door", self.sum_opened_door)
        if "picked_up" in infos:
            if infos["picked_up"]:
                self.sum_picked_up += 1
            self.logger.record("info/sum_picked_up", self.sum_picked_up)
        if "no_shield_action" in infos:
            if infos["no_shield_action"]:
                self.no_shield_action += 1
            self.logger.record("info/no_shield_action", self.no_shield_action)
        return True