from dataclasses import dataclass, field
import itertools
from typing import Optional

from minigrid.core.constants import (
    COLOR_NAMES,
    IDX_TO_COLOR,
)


@dataclass(frozen=True, eq=True)
class KeyState:
    """Color and (col, row) cell of a key."""

    color: str = ""
    col: int = 0
    row: int = 0


@dataclass(frozen=True, eq=True)
class BallState:
    """Color and (col, row) cell of a ball."""

    color: str = ""
    col: int = 0
    row: int = 0


@dataclass(frozen=True, eq=True)
class BoxState:
    """Color and (col, row) cell of a box."""

    color: str = ""
    col: int = 0
    row: int = 0


@dataclass(frozen=True, eq=True)
class DoorState:
    """Color of a door and whether it is locked.

    ``to_state`` sets ``locked`` to the negation of the "<Color>DoorOpen" flag.
    """

    color: str = ""
    locked: bool = True


@dataclass(frozen=True, eq=True)
class AdversaryState:
    """Color, cell, optional view direction and carried item of an adversary."""

    color: str = ""
    col: int = 0
    row: int = 0
    view: Optional[int] = None  # None when no view value is known for this adversary
    carrying: str = ""

    def csv(self, delim=",") -> str:
        """Return "col<delim>row", followed by "<delim>view" only when view is set."""
        parts = [self.col, self.row]
        if self.view is not None:
            parts.append(self.view)
        return delim.join(str(part) for part in parts)


@dataclass(frozen=True, eq=True)
class State:
    """Snapshot of the agent plus the other objects on the grid.

    Field order matters for positional construction:
    ``State(colAgent, rowAgent, carrying, viewAgent, ...)``.
    """

    colAgent: int
    rowAgent: int
    carrying: Optional[str] = None
    viewAgent: Optional[int] = None  # None when no agent view value is available
    adversaries: tuple = field(default_factory=tuple)  # AdversaryState-like objects
    balls: tuple = field(default_factory=tuple)
    boxes: tuple = field(default_factory=tuple)
    keys: tuple = field(default_factory=tuple)
    doors: tuple = field(default_factory=tuple)
    lockeddoors: tuple = field(default_factory=tuple)

    def _view(self) -> bool:
        """Return True when the agent's view direction is known."""
        return self.viewAgent is not None

    @staticmethod
    def _adversary_values(adv) -> list:
        """Return ``[col, row]`` of ``adv``, plus its view when it has one.

        Uses the same rule as ``AdversaryState.csv`` so that ``csv``, ``tuple``
        and ``feature_space`` always agree on whether a view is present
        (instead of emitting a literal ``None`` value).
        """
        values = [adv.col, adv.row]
        if adv.view is not None:
            values.append(adv.view)
        return values

    def _values(self) -> list:
        """Return the flat feature vector whose names are ``feature_space``."""
        values = [self.colAgent, self.rowAgent]
        if self._view():
            values.append(self.viewAgent)
        for adv in self.adversaries:
            values.extend(self._adversary_values(adv))
        return values

    def csv(self, delim=",") -> str:
        """Serialize the feature vector as ``delim``-joined text."""
        return delim.join(str(value) for value in self._values())

    def tuple(self) -> tuple:
        """Return the feature vector as a tuple (same order as ``csv``)."""
        # Inside a method body the name ``tuple`` resolves to the builtin,
        # not to this method (class scope is not enclosing for functions).
        return tuple(self._values())

    def adversary_distance(self, adversary) -> int:
        """Estimate the steps ``adversary`` needs to reach the agent's cell.

        Starts from the Manhattan distance and adds turn penalties based on
        ``adversary.view``: +1 when a single turn is needed, +2 when the
        adversary faces directly away on a straight line. Returns 0 when both
        occupy the same cell.

        Direction convention (from the original comments): 0 = east,
        1 = south, 2 = west, 3 = north; a positive column offset means
        "move north" and a positive row offset means "move east".
        NOTE(review): in MiniGrid ``col`` is normally the x axis, which would
        make column offsets east/west — confirm this mapping against the model.
        """
        col_offset = self.colAgent - adversary.col
        row_offset = self.rowAgent - adversary.row
        distance = abs(col_offset) + abs(row_offset)
        if distance == 0:
            # Already on the agent's cell: no move and no turn required.
            return 0

        view = adversary.view
        if row_offset == 0:
            # Straight line along the north/south axis.
            if view in (0, 2):  # facing east/west: one turn
                distance += 1
            elif col_offset > 0 and view == 1:  # must go north, faces south
                distance += 2
            elif col_offset < 0 and view == 3:  # must go south, faces north
                distance += 2
        elif col_offset == 0:
            # Straight line along the east/west axis.
            if view in (1, 3):  # facing north/south: one turn
                distance += 1
            elif row_offset > 0 and view == 2:  # must go east, faces west
                distance += 2
            elif row_offset < 0 and view == 0:  # must go west, faces east
                distance += 2
        else:
            # Needs movement along both axes: penalize facing away from either.
            if row_offset > 0 and view == 2:  # must go east, faces west
                distance += 1
            if row_offset < 0 and view == 0:  # must go west, faces east
                distance += 1
            if col_offset < 0 and view == 3:  # must go south, faces north
                distance += 1
            if col_offset > 0 and view == 1:  # must go north, faces south
                distance += 1
        return distance

    @property
    def feature_space(self) -> list[str]:
        """Names of the entries produced by ``csv`` / ``tuple``, in order."""
        names = ["colAgent", "rowAgent"]
        if self._view():
            names.append("viewAgent")
        for adv in self.adversaries:
            suffix = adv.color.capitalize()
            names += [f"col{suffix}", f"row{suffix}"]
            if adv.view is not None:
                names.append(f"view{suffix}")
        return names


def to_state(ints, booleans):
    """Build a ``State`` from integer and boolean variable valuations.

    Args:
        ints: mapping of variable name to a value accepted by ``int()``,
            e.g. "colAgent", "rowAgent", optional "viewAgent", "col<Color>",
            "row<Color>", optional "view<Color>", "col<Color>Key", "col<Color>Box".
        booleans: mapping of formula/variable name to a truth value. True
            entries of the form "<Owner>Carrying<Item>" record what <Owner>
            carries; "<Color>DoorOpen" gives the open state of a door.

    Raises:
        NotImplementedError: if a "<Color>LockedDoorOpen" flag is present.
    """
    ints = {key: int(value) for key, value in ints.items()}

    # "<Owner>Carrying<Item>" -> carrying[Owner] = Item (only for true flags).
    carrying = {}
    for formula, value in booleans.items():
        if not value:
            continue
        owner, sep, item = formula.partition("Carrying")
        if sep:
            carrying[owner] = item

    adversaries = ()
    boxes = ()
    keys = ()
    doors = ()
    for color in COLOR_NAMES:
        color = color.capitalize()

        if "col" + color in ints:
            adversaries += (
                AdversaryState(
                    color,
                    ints["col" + color],
                    ints["row" + color],
                    ints.get("view" + color),  # None when the model has no view for it
                    carrying=carrying.get(color, ""),
                ),
            )

        box_id = color + "Box"
        if "col" + box_id in ints:
            boxes += (BoxState(color, ints["col" + box_id], ints["row" + box_id]),)

        key_id = color + "Key"
        if "col" + key_id in ints:
            keys += (KeyState(color, ints["col" + key_id], ints["row" + key_id]),)

        if color + "DoorOpen" in booleans:
            doors += (DoorState(color, locked=not booleans[color + "DoorOpen"]),)
        elif color + "LockedDoorOpen" in booleans:
            # Explicit raise instead of ``assert``, which is stripped under -O.
            raise NotImplementedError(
                f"Unsupported locked door flag: {color}LockedDoorOpen"
            )

    # Pass agent fields by keyword: State's positional order is
    # (colAgent, rowAgent, carrying, viewAgent), so a positional view would
    # land in ``carrying``.
    return State(
        ints["colAgent"],
        ints["rowAgent"],
        carrying=carrying.get("Agent", ""),
        viewAgent=ints.get("viewAgent"),
        adversaries=adversaries,
        boxes=boxes,
        keys=keys,
        doors=doors,
    )