from __future__ import annotations

from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
from minigrid.minigrid_env import MiniGridEnv, is_slippery
from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject

import numpy as np

class AdversaryEnv(MiniGridEnv):

    """
    ## Description

    MiniGrid environment in which adversaries move after every agent step.

    After the agent acts (and only if the episode has not already terminated),
    each adversary in ``self.adversaries`` chooses an action through
    ``adversary.get_action(self)`` and executes it. If an adversary moves
    forward onto the agent's cell, the episode terminates,
    ``info["collision"]`` is set to ``True`` and the reward is replaced by
    ``self.collision_penalty`` (``-1`` by default).

    ``_gen_grid`` here only builds the outer wall and does not place the
    agent. This class is therefore presumably meant to be subclassed with a
    concrete layout (NOTE(review): confirm).

    Attributes read but not defined in this class are assumed to be provided
    by the ``MiniGridEnv`` base of this fork or by a subclass (TODO confirm):
    ``adversaries``, ``background_tiles``, ``trajectory`` and
    ``get_neighbours_prob``.
    """

    def __init__(
        self,
        width=7,
        height=6,
        generate_wall=True,
        generate_lava=False,
        generate_slippery=False,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Create the environment.

        Args:
            width: Grid width in cells, including the border wall.
            height: Grid height in cells, including the border wall.
            generate_wall: Accepted for interface compatibility; not used here.
            generate_lava: Accepted for interface compatibility; not used here.
            generate_slippery: Accepted for interface compatibility; not used
                here.
            max_steps: Episode step limit. When ``None`` it defaults to
                ``10 * (width * height) ** 2``.
            **kwargs: Forwarded unchanged to ``MiniGridEnv.__init__``.
        """
        if max_steps is None:
            max_steps = 10 * (width * height) ** 2
        mission_space = MissionSpace(mission_func=self._gen_mission)
        # Reward assigned in ``step`` when an adversary collides with the agent.
        self.collision_penalty = -1
        super().__init__(
            mission_space=mission_space, width=width, height=height, max_steps=max_steps, **kwargs
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string for this environment."""
        return "Finish your task while avoiding the adversaries"

    def _gen_grid(self, width, height):
        """Build an empty ``width`` x ``height`` grid enclosed by walls.

        No agent, goal or other objects are placed here.
        """
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

    def step(self, action):
        """Advance the agent by ``action``, then move every adversary.

        Before the agent moves, background tiles are restored. For every entry
        of ``self.background_tiles`` whose grid cell is currently empty, the
        tile is put back on the grid and that cell's background is cleared.
        The entry is then removed from the mapping.

        Args:
            action: Agent action, forwarded to ``MiniGridEnv.step``.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple from
            ``MiniGridEnv.step``. If any adversary collided with the agent,
            ``terminated`` is ``True``, ``info["collision"]`` is ``True`` and
            ``reward`` is ``self.collision_penalty``.
        """
        # Collect the restored positions first and delete them afterwards,
        # because a dict must not change size while it is being iterated.
        restored = []
        for position, tile in self.background_tiles.items():
            if self.grid.get(*position) is None:
                self.grid.set(*position, tile)
                self.grid.set_background(*position, None)
                restored.append(position)
        for position in restored:
            del self.background_tiles[position]

        obs, reward, terminated, truncated, info = super().step(action)

        agent_pos = self.agent_pos

        if not terminated:
            # All adversaries still move (and are logged) even after one of
            # them has collided during this step.
            for adversary in self.adversaries.values():
                collided = self.move_adversary(adversary, agent_pos)
                self.trajectory.append((adversary.color, adversary.adversary_pos, adversary.adversary_dir))
                if collided:
                    terminated = True
                    info["collision"] = True
                    # Fall back to -1 if ``collision_penalty`` is missing,
                    # e.g. because a subclass removed it.
                    reward = getattr(self, "collision_penalty", -1)

        return obs, reward, terminated, truncated, info

    def move_adversary(self, adversary, agent_pos):
        """Execute one action chosen by ``adversary`` and report a collision.

        The action is obtained from ``adversary.get_action(self)`` and is
        interpreted against ``self.actions``. A forward move that starts on a
        slippery cell goes to a neighbour sampled from that cell's transition
        probabilities instead of the cell straight ahead.

        Args:
            adversary: Object exposing ``adversary_pos``, ``adversary_dir``,
                ``dir_vec()``, ``carrying`` and ``get_action(env)``.
            agent_pos: Current agent position, compared coordinate-wise with
                the adversary's forward target.

        Returns:
            ``True`` if the action was a forward move whose target position
            equals ``agent_pos``, otherwise ``False``. The result does not
            depend on whether the target cell could actually be entered.

        Raises:
            ValueError: If the action is not left, right, forward, pickup,
                drop, toggle or done.
        """
        cur_pos = adversary.adversary_pos
        current_cell = self.grid.get(*cur_pos)
        # NOTE(review): assumes dir_vec() returns a NumPy vector, so adding it
        # to a (possibly tuple) position is element-wise. Confirm.
        fwd_pos = cur_pos + adversary.dir_vec()
        fwd_cell = self.grid.get(*fwd_pos)
        collision = False

        action = adversary.get_action(self)

        if action == self.actions.forward and is_slippery(current_cell):
            # Slippery cell: replace the straight-ahead target with a sampled
            # neighbour. The forward branch below then uses the sampled target.
            # NOTE(review): this draws from the global NumPy RNG, which is
            # presumably not seeded by the env's reset(seed=...). Consider
            # self.np_random; confirm against the agent's own slippery logic.
            probabilities = current_cell.get_probabilities(adversary.adversary_dir)
            possible_fwd_pos, prob = self.get_neighbours_prob(cur_pos, probabilities)
            fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
            fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
            fwd_cell = self.grid.get(*fwd_pos)

        # Rotate left
        if action == self.actions.left:
            adversary.adversary_dir = (adversary.adversary_dir - 1) % 4

        # Rotate right
        elif action == self.actions.right:
            adversary.adversary_dir = (adversary.adversary_dir + 1) % 4

        # Move forward (onto the straight-ahead or sampled slippery target)
        elif action == self.actions.forward:
            if fwd_pos[0] == agent_pos[0] and fwd_pos[1] == agent_pos[1]:
                collision = True
            if fwd_cell is None or fwd_cell.can_overlap():
                adversary.adversary_pos = tuple(fwd_pos)

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if adversary.carrying is None:
                    adversary.carrying = fwd_cell
                    # (-1, -1) marks the carried object as off-grid.
                    adversary.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(fwd_pos[0], fwd_pos[1], None)

        # Drop an object
        elif action == self.actions.drop:
            if not fwd_cell and adversary.carrying:
                self.grid.set(fwd_pos[0], fwd_pos[1], adversary.carrying)
                adversary.carrying.cur_pos = fwd_pos
                adversary.carrying = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        # Done action: no-op for adversaries
        elif action == self.actions.done:
            pass

        else:
            raise ValueError(f"Unknown action: {action}")

        return collision