import gymnasium as gym
import numpy as np
import random

from minigrid.core.actions import Actions
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX

from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot

from helpers import get_action_index_mapping
from shieldhandlers import ShieldHandler


class OneHotShieldingWrapper(gym.core.ObservationWrapper):
    """One-hot encode image observations and stack the most recent frames.

    The wrapped environment must emit dict observations with the keys
    ``"data"`` (an image whose last axis holds object / color / state
    indices) and ``"action_mask"``.  The image is one-hot encoded, flattened,
    extended with the one-hot agent direction and pushed into a fixed-size
    frame buffer.  The returned observation is a dict holding the
    concatenation of the buffered frames under ``"data"`` and the untouched
    ``"action_mask"``.

    Args:
        env: Environment to wrap.
        vector_index: Index of this environment copy; only index 0 prints the
            travelled-distance debug statistic.
        framestack: Number of frames concatenated into one observation.
    """

    def __init__(self, env, vector_index, framestack):
        super().__init__(env)
        self.framestack = framestack
        # 49 = 7x7 field of vision; every cell gets one one-hot slot per entry
        # of OBJECT_TO_IDX, COLORS and STATE_TO_IDX.
        # +4: Direction.
        self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
        # Start position and visited positions of the running episode; only
        # used for the debug statistic printed in `observation`.
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        # Max distance from the start position for the last 100 episodes.
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._clear_frame_buffer()

        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
            }
        )

    def _clear_frame_buffer(self):
        """Fill the frame buffer with all-zero frames.

        The frames are float32 so that the stacked observation keeps the
        dtype declared in ``observation_space`` even while real frames and
        padding frames are mixed (float64 padding would upcast the whole
        concatenation to float64).
        """
        for _ in range(self.framestack):
            self.frame_buffer.append(np.zeros((self.single_frame_dim,), dtype=np.float32))

    def observation(self, obs):
        """Return the stacked one-hot observation for ``obs``.

        Args:
            obs: Dict with the keys ``"data"`` and ``"action_mask"``.

        Returns:
            Dict with ``"data"`` (concatenated frame buffer) and
            ``"action_mask"`` (passed through unchanged).
        """
        # Read the agent state from the base environment explicitly instead of
        # relying on implicit attribute forwarding through the wrapper stack.
        base_env = self.unwrapped

        # Debug output: max-x/y positions to watch exploration progress.
        if base_env.step_count == 0:
            # A new episode started: drop the frames of the previous episode.
            self._clear_frame_buffer()
            if self.vector_index == 0:
                if self.x_positions:
                    dx = np.array(self.x_positions) - self.init_x
                    dy = np.array(self.y_positions) - self.init_y
                    max_diff = np.max(np.sqrt(dx ** 2 + dy ** 2))
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = base_env.agent_pos[0]
                self.init_y = base_env.agent_pos[1]

        self.x_positions.append(base_env.agent_pos[0])
        self.y_positions.append(base_env.agent_pos[1])

        image = obs["data"]
        # One-hot each channel of the last dim separately, then flatten.
        objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
        colors = one_hot(image[:, :, 1], depth=len(COLORS))
        states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))

        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(base_env.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)

        return {
            "data": np.concatenate(self.frame_buffer),
            "action_mask": obs["action_mask"],
        }


class MiniGridShieldingWrapper(gym.core.Wrapper):
    """Attach a shield-derived action mask to every observation.

    Observations become dicts with the wrapped environment's ``"image"``
    under ``"data"`` and an int8 mask of length ``env.action_space.n`` under
    ``"action_mask"`` (non-zero = action allowed).

    Args:
        env: Environment to wrap; its observation space must contain an
            ``"image"`` entry.
        shield_creator: Handler whose ``create_shield(env=...)`` returns the
            shield.  The shield is used as a mapping from a state string to
            the allowed actions, each exposing ``labels`` and ``prob``.
        shield_query_creator: Callable mapping the wrapped env to the state
            string used to query the shield.
        create_shield_at_reset: Rebuild the shield on every ``reset`` (only
            when ``mask_actions`` is true).
        mask_actions: When false, every action is always allowed.
    """

    def __init__(self,
                 env,
                 shield_creator: ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "action_mask": Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
            }
        )
        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        self.shield = shield_creator.create_shield(env=self.env)
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        print(F"Shielding is {self.mask_actions}")

    def create_action_mask(self):
        """Build the action mask for the current environment state.

        If the shield has a non-empty entry for the current state, only the
        actions listed there are allowed; otherwise all actions are allowed.
        ``pickup`` is always allowed when facing a key and ``toggle`` when
        facing a door.

        Returns:
            int8 array of length ``max_available_actions``.

        Raises:
            AssertionError: If a shield action has no action-index mapping or
                is listed with probability 0.
        """
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        cur_pos_str = self.shield_query_creator(self.env)
        allowed_actions = self.shield[cur_pos_str] if cur_pos_str in self.shield else None

        if allowed_actions:
            # The shield restricts this state: start from "nothing allowed".
            mask = np.zeros(self.max_available_actions, dtype=np.int8)
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)
                # Raise explicitly instead of `assert False`, which is
                # stripped when Python runs with -O.
                if index is None:
                    raise AssertionError(
                        F"No mapping for action {list(allowed_action.labels)}; "
                        F"Shield at pos {cur_pos_str}, shield {allowed_actions}"
                    )
                if allowed_action.prob == 0:
                    raise AssertionError(
                        F"Shield at pos {cur_pos_str} lists action "
                        F"{list(allowed_action.labels)} with probability 0"
                    )
                mask[index] = 1
        else:
            # State unknown to the shield (or empty entry): allow everything.
            mask = np.ones(self.max_available_actions, dtype=np.int8)

        # Query the grid on the base environment so this also works when
        # `self.env` is itself a wrapper.
        base_env = self.env.unwrapped
        front_tile = base_env.grid.get(base_env.front_pos[0], base_env.front_pos[1])

        if front_tile is not None:
            if front_tile.type == "key":
                mask[Actions.pickup] = 1
            if front_tile.type == "door":
                mask[Actions.toggle] = 1
        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the environment and return the masked observation and info."""
        obs, infos = self.env.reset(seed=seed, options=options)

        if self.create_shield_at_reset and self.mask_actions:
            self.shield = self.shield_creator.create_shield(env=self.env)

        mask = self.create_action_mask()
        return {
            "data": obs["image"],
            "action_mask": mask
        }, infos

    def step(self, action):
        """Step the environment and attach the mask for the resulting state."""
        orig_obs, rew, done, truncated, info = self.env.step(action)

        mask = self.create_action_mask()
        obs = {
            "data": orig_obs["image"],
            "action_mask": mask,
        }

        return obs, rew, done, truncated, info



class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Expose image observations and compute shield-based action masks.

    Unlike a dict-observation wrapper, observations here are just the wrapped
    environment's ``"image"``; the mask has to be queried separately through
    ``create_action_mask``.

    Args:
        env: Environment to wrap; its observation space must contain an
            ``"image"`` entry.
        shield_creator: Handler whose ``create_shield(env=...)`` returns the
            shield.  The shield is used as a mapping from a state string to
            the allowed actions, each exposing ``labels`` and ``prob``.
        shield_query_creator: Callable mapping the wrapped env to the state
            string used to query the shield.
        create_shield_at_reset: When true (default) the shield is rebuilt on
            every ``reset``; when false it is built on the first ``reset``
            only and reused afterwards.
        mask_actions: When false, every action is always allowed.
    """

    def __init__(self,
                 env,
                 shield_creator : ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset = True,
                 mask_actions=True,
                 ):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]

        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        # Built lazily in `reset`; None means "no shield created yet".
        self.shield = None
        print(F"Shielding is {self.mask_actions}")

    def create_action_mask(self):
        """Build the action mask for the current environment state.

        If the shield has a non-empty entry for the current state, every
        listed action is allowed with its own probability ``prob`` (sampled
        independently); unlisted actions are blocked.  Otherwise all actions
        are allowed.  ``toggle`` is always allowed when facing a door.

        Returns:
            int8 array of length ``max_available_actions``.

        Raises:
            RuntimeError: If masking is enabled and no shield exists yet
                because ``reset`` has not been called.
            AssertionError: If a shield action has no action-index mapping.
        """
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        if self.shield is None:
            raise RuntimeError(
                "No shield available: call reset() before create_action_mask()"
            )

        cur_pos_str = self.shield_query_creator(self.env)
        allowed_actions = self.shield[cur_pos_str] if cur_pos_str in self.shield else None

        if allowed_actions:
            # The shield restricts this state: start from "nothing allowed".
            mask = np.zeros(self.max_available_actions, dtype=np.int8)
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)
                # Raise explicitly instead of `assert(False)`, which is
                # stripped when Python runs with -O.
                if index is None:
                    raise AssertionError(
                        F"No mapping for action {list(allowed_action.labels)}"
                    )
                # NOTE(review): sampling can leave every action masked out;
                # confirm that consumers of the mask cope with an all-zero mask.
                mask[index] = random.choices(
                    [0.0, 1.0],
                    weights=(1 - allowed_action.prob, allowed_action.prob),
                )[0]
        else:
            # State unknown to the shield (or empty entry): allow everything.
            mask = np.ones(self.max_available_actions, dtype=np.int8)

        # Query the grid on the base environment so this also works when
        # `self.env` is itself a wrapper.
        base_env = self.env.unwrapped
        front_tile = base_env.grid.get(base_env.front_pos[0], base_env.front_pos[1])

        if front_tile is not None and front_tile.type == "door":
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the environment, (re)build the shield, return image and info."""
        obs, infos = self.env.reset(seed=seed, options=options)

        # Always build a shield on the first reset so a mask can be computed
        # even when per-reset re-creation is disabled.
        if self.create_shield_at_reset or self.shield is None:
            self.shield = self.shield_creator.create_shield(env=self.env)

        return obs["image"], infos

    def step(self, action):
        """Step the environment and return only the image observation."""
        orig_obs, rew, done, truncated, info = self.env.step(action)
        obs = orig_obs["image"]

        return obs, rew, done, truncated, info