import gymnasium as gym
import numpy as np


from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot

class OneHotWrapper(gym.core.ObservationWrapper):
    """One-hot encode grid observations and stack the most recent frames.

    Incoming observations are dicts with an image under ``"data"`` and an
    action mask under ``"avail_actions"``. The three image channels (object
    type, color, state) are one-hot encoded, flattened, concatenated with the
    one-hot agent direction and pushed into a fixed-length frame buffer. The
    emitted ``"data"`` entry is the concatenation of that buffer; the action
    mask is passed through unchanged.

    Args:
        env: The environment to wrap. Its unwrapped base env must expose
            ``step_count``, ``agent_pos`` and ``agent_dir``.
        vector_index: Index of this env copy; only index 0 tracks and prints
            the distance-travelled debug statistic.
        framestack: Number of most recent frames concatenated into ``"data"``.
    """

    def __init__(self, env, vector_index, framestack):
        super().__init__(env)
        self.framestack = framestack
        # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.
        # +4: Direction.
        self.single_frame_dim = 49 * (11 + 6 + 3) + 4
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._reset_frame_buffer()

        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(
                    0.0,
                    1.0,
                    shape=(self.single_frame_dim * self.framestack,),
                    dtype=np.float32,
                ),
                "avail_actions": gym.spaces.Box(
                    0, 10, shape=(env.action_space.n,), dtype=int
                ),
            }
        )

        print(F"Set obersvation space to {self.observation_space}")

    def _reset_frame_buffer(self):
        """Fill the frame buffer with all-zero frames.

        The zero frames are float32 so that the stacked observation matches
        the dtype declared in ``observation_space`` from the first step on.
        """
        self.frame_buffer.clear()
        self.frame_buffer.extend(
            np.zeros((self.single_frame_dim,), dtype=np.float32)
            for _ in range(self.framestack)
        )

    def _report_distance_travelled(self):
        """Print the running mean of the per-episode max distance from start.

        Uses the positions recorded since the last call, then clears them.
        Does nothing when no positions have been recorded yet.
        """
        if not self.x_positions:
            return
        distances = np.sqrt(
            (np.array(self.x_positions) - self.init_x) ** 2
            + (np.array(self.y_positions) - self.init_y) ** 2
        )
        self.x_y_delta_buffer.append(np.max(distances))
        print(
            "100-average dist travelled={}".format(
                np.mean(self.x_y_delta_buffer)
            )
        )
        self.x_positions = []
        self.y_positions = []

    def observation(self, obs):
        """Return the stacked one-hot observation for ``obs``.

        Args:
            obs: Dict with the raw image under ``"data"`` and the action mask
                under ``"avail_actions"``.

        Returns:
            Dict with the concatenated frame buffer under ``"data"`` and the
            untouched action mask under ``"avail_actions"``.
        """
        # Read the env state from the base env explicitly instead of relying
        # on implicit attribute forwarding through the wrapper stack.
        base_env = self.unwrapped
        track_positions = self.vector_index == 0

        # A step count of zero marks the first observation of an episode.
        if base_env.step_count == 0:
            self._reset_frame_buffer()
            if track_positions:
                # Debug output: max-x/y positions to watch exploration progress.
                self._report_distance_travelled()
                self.init_x = base_env.agent_pos[0]
                self.init_y = base_env.agent_pos[1]

        # Only the env copy that reports the statistic records positions;
        # the other copies never clear these lists, so recording there would
        # grow them without bound.
        if track_positions:
            self.x_positions.append(base_env.agent_pos[0])
            self.y_positions.append(base_env.agent_pos[1])

        image = obs["data"]

        # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(image[:, :, 0], depth=11)
        colors = one_hot(image[:, :, 1], depth=6)
        states = one_hot(image[:, :, 2], depth=3)

        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(base_env.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction]).astype(
            np.float32, copy=False
        )
        self.frame_buffer.append(single_frame)

        return {
            "data": np.concatenate(self.frame_buffer),
            "avail_actions": obs["avail_actions"],
        }


class MiniGridEnvWrapper(gym.core.Wrapper):
    """Expose observations as a dict of image plus action mask.

    Observations returned by ``reset`` and ``step`` are dicts with the wrapped
    env's image under ``"data"`` and an int8 action mask under
    ``"avail_actions"``.

    Args:
        env: The environment to wrap. Its observations must contain an
            ``"image"`` entry and its action space must expose ``n``.
        shield: Stored on the instance as ``self.shield``. The lookup that
            would consume it in ``create_action_mask`` is currently disabled.
    """

    def __init__(self, env, shield):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "avail_actions": Box(
                    0, 10, shape=(self.max_available_actions,), dtype=np.int8
                ),
            }
        )

        self.shield = shield

    def create_action_mask(self):
        """Build the action mask for the agent's current state.

        Returns:
            An int8 array of length ``max_available_actions`` where 1 marks
            an allowed action. At present only action 0 is marked allowed.
        """
        # Read the agent state from the base env explicitly instead of
        # relying on attribute forwarding through inner wrappers.
        base_env = self.unwrapped
        coordinates = base_env.agent_pos
        view_direction = base_env.agent_dir
        print(F"Agent pos is {coordinates} and direction {view_direction} ")

        allowed_actions = []

        mask = np.zeros(self.max_available_actions, dtype=np.int8)

        # TODO: re-enable the shield lookup. The disabled logic keyed the
        # shield by the state string
        #   f"[!AgentDone\t& xAgent={x}\t& yAgent={y}\t& viewAgent={dir}]"
        # and set mask[action[0]] = 1 for every allowed action of that state,
        # falling back to allowing every action when the state is not in the
        # shield. (The disabled fallback iterated `len(mask)` directly, which
        # is not iterable; it needs `mask[:] = 1` or `range(len(mask))`.)

        print(F"Allowed actions for position {coordinates} and view {view_direction} are {allowed_actions}")
        mask[0] = 1
        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and return the initial dict observation.

        Args:
            seed: Forwarded to the wrapped env's ``reset``.
            options: Forwarded to the wrapped env's ``reset``.

        Returns:
            Tuple of the dict observation and the wrapped env's info dict.
        """
        # Forward seed/options; dropping them would make seeding ineffective.
        obs, infos = self.env.reset(seed=seed, options=options)
        # Use the same mask construction as step() so the first observation
        # of an episode does not advertise "no action available".
        return {
            "data": obs["image"],
            "avail_actions": self.create_action_mask(),
        }, infos

    def step(self, action):
        """Step the wrapped env and return the dict observation.

        Args:
            action: Action passed to the wrapped env.

        Returns:
            The wrapped env's step tuple, with the observation replaced by a
            dict holding the image and the action mask for the new state.
        """
        print(F"Performed action in step: {action}")
        orig_obs, rew, done, truncated, info = self.env.step(action)

        obs = {
            "data": orig_obs["image"],
            "avail_actions": self.create_action_mask(),
        }
        return obs, rew, done, truncated, info
    
    


class ImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the image as the only observation output, no language/mission.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]
        print(F"Set obersvation space to {self.observation_space}")

    def observation(self, obs):
        """Return only the image entry of the wrapped env's observation.

        Args:
            obs: Dict observation containing an ``"image"`` entry.

        Returns:
            The value stored under ``"image"``, matching the declared
            ``observation_space``.
        """
        # NOTE(review): this previously returned
        # {"data": obs["image"], "Test": obs["Test"]}, which contradicted the
        # declared observation space and required a "Test" key in the
        # observation. Confirm no caller depended on that dict form.
        return obs["image"]