You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
6.3 KiB
163 lines
6.3 KiB
from dataclasses import dataclass, field
|
|
import itertools
|
|
from typing import Optional
|
|
|
|
from minigrid.core.constants import (
|
|
COLOR_NAMES,
|
|
IDX_TO_COLOR
|
|
)
|
|
|
|
@dataclass(eq=True, frozen=True)
class KeyState:
    """Immutable, hashable record of a key: its colour and grid cell."""

    # Colour name of the key; empty string when unspecified.
    color: str = ""
    # Column index of the key's cell.
    col: int = 0
    # Row index of the key's cell.
    row: int = 0
|
|
|
|
@dataclass(eq=True, frozen=True)
class BallState:
    """Immutable, hashable record of a ball: its colour and grid cell."""

    # Colour name of the ball; empty string when unspecified.
    color: str = ""
    # Column index of the ball's cell.
    col: int = 0
    # Row index of the ball's cell.
    row: int = 0
|
|
|
|
@dataclass(eq=True, frozen=True)
class BoxState:
    """Immutable, hashable record of a box: its colour and grid cell."""

    # Colour name of the box; empty string when unspecified.
    color: str = ""
    # Column index of the box's cell.
    col: int = 0
    # Row index of the box's cell.
    row: int = 0
|
|
|
|
@dataclass(eq=True, frozen=True)
class DoorState:
    """Immutable, hashable record of a door: its colour and lock flag."""

    # Colour name of the door; empty string when unspecified.
    color: str = ""
    # Doors are considered locked unless stated otherwise.
    locked: bool = True
|
|
|
|
@dataclass(eq=True, frozen=True)
class AdversaryState:
    """Immutable, hashable snapshot of one adversary.

    Holds the adversary's colour, grid cell, optional view direction and
    the name of whatever it is carrying.
    """

    # Colour name identifying the adversary.
    color: str = ""
    # Column index of the adversary's cell.
    col: int = 0
    # Row index of the adversary's cell.
    row: int = 0
    # View direction; None when the direction is not tracked.
    view: Optional[int] = None
    # Name of the carried item; empty string when nothing is carried.
    carrying: str = ""

    def csv(self, delim=",") -> str:
        """Return ``col``, ``row`` and (if set) ``view`` joined by *delim*."""
        position = f"{self.col}{delim}{self.row}"
        if self.view is None:
            return position
        return f"{position}{delim}{self.view}"
|
|
|
|
@dataclass(eq=True, frozen=True)
class State:
    """Immutable, hashable snapshot of the agent and the objects around it.

    The agent is described by its cell, an optional view direction and an
    optional carried item; every other entity is stored in a tuple so the
    whole state stays hashable.
    """

    # Column / row index of the agent's cell.
    colAgent: int
    rowAgent: int
    # Name of the item the agent carries, if any.
    carrying: Optional[str] = None
    # Agent view direction; None when the direction is not tracked.
    viewAgent: Optional[int] = None
    # Per-entity tuples; adversaries are read via .color/.col/.row/.view.
    adversaries: tuple = field(default_factory=tuple)
    balls: tuple = field(default_factory=tuple)
    boxes: tuple = field(default_factory=tuple)
    keys: tuple = field(default_factory=tuple)
    doors: tuple = field(default_factory=tuple)
    lockeddoors: tuple = field(default_factory=tuple)

    def _view(self) -> bool:
        """Tell whether the agent's view direction is part of this state."""
        return self.viewAgent is not None

    def csv(self, delim=",") -> str:
        """Serialise the agent and all adversaries into one *delim*-joined row.

        The agent contributes ``col, row`` plus ``view`` when it is set.
        Every adversary always contributes ``col, row, view`` (so an unset
        adversary view is rendered as ``None``), which mirrors the column
        names produced by ``feature_space``.
        """
        cells = [f"{self.colAgent}", f"{self.rowAgent}"]
        if self._view():
            cells.append(f"{self.viewAgent}")
        for adv in self.adversaries:
            cells.extend((f"{adv.col}", f"{adv.row}", f"{adv.view}"))
        return delim.join(cells)

    def tuple(self) -> tuple:
        """Flatten the agent and all adversaries into a tuple of raw values.

        Unlike ``csv``, adversary view directions are only included when the
        agent's own view direction is set.
        """
        with_view = self._view()
        values = [self.colAgent, self.rowAgent]
        if with_view:
            values.append(self.viewAgent)
        for adv in self.adversaries:
            values.append(adv.col)
            values.append(adv.row)
            if with_view:
                values.append(adv.view)
        # Inside a method body ``tuple`` still resolves to the builtin.
        return tuple(values)

    def adversary_distance(self, adversary) -> int:
        """Return the Manhattan distance to *adversary* plus a turn penalty.

        Args:
            adversary: Object exposing ``col``, ``row`` and ``view``.

        Returns:
            ``|dcol| + |drow|`` increased by 0, 1 or 2 depending on the
            view-direction checks below.

        NOTE(review): assuming MiniGrid's direction encoding (0=east,
        1=south, 2=west, 3=north - TODO confirm), the direction tested in
        each branch does not obviously match the axis being compared, and
        the diagonal branch inspects the agent's view rather than the
        adversary's. The arithmetic is kept exactly as found; confirm the
        intended heuristic before relying on it.
        """
        col_gap = self.colAgent - adversary.col
        row_gap = self.rowAgent - adversary.row
        facing = adversary.view

        penalty = 0
        if row_gap == 0:
            # Agent and adversary share a row.
            if facing in (0, 2):
                penalty = 1
            elif (col_gap > 0 and facing == 1) or (col_gap < 0 and facing == 3):
                penalty = 2
        elif col_gap == 0:
            # Agent and adversary share a column.
            if facing in (1, 3):
                penalty = 1
            elif (row_gap > 0 and facing == 2) or (row_gap < 0 and facing == 0):
                penalty = 2
        else:
            # Offset on both axes: each matching check adds one step, and
            # the checks are independent, so several may apply at once.
            own_view = self.viewAgent
            checks = (
                row_gap > 0 and own_view == 2,
                row_gap < 0 and own_view == 1,
                col_gap < 0 and own_view == 3,
                col_gap > 0 and own_view == 1,
            )
            penalty = sum(1 for hit in checks if hit)

        return abs(col_gap) + abs(row_gap) + penalty

    @property
    def feature_space(self) -> list[str]:
        """Column names matching the values emitted by ``csv``.

        ``viewAgent`` is listed only when the agent's view is set; each
        adversary always contributes ``col<Color>``, ``row<Color>`` and
        ``view<Color>``.
        """
        names = ["colAgent", "rowAgent"]
        if self._view():
            names.append("viewAgent")
        for adv in self.adversaries:
            label = adv.color.capitalize()
            names += [f"col{label}", f"row{label}", f"view{label}"]
        return names
|
|
|
|
|
|
|
|
def to_state(ints, booleans):
    """Build a ``State`` from flat integer and boolean valuations.

    Args:
        ints: Mapping from variable name to a value accepted by ``int()``.
            ``colAgent`` and ``rowAgent`` are required; ``viewAgent`` is
            optional. For each colour in ``COLOR_NAMES`` (capitalised), a
            ``col<Color>`` entry announces an adversary whose ``row<Color>``
            must also be present and whose ``view<Color>`` is optional.
        booleans: Mapping from label to truth value. True labels of the
            form ``<Owner>Carrying<Item>`` record that ``<Owner>`` (``Agent``
            or a capitalised colour) carries ``<Item>``. ``<Color>DoorOpen``
            labels describe doors.

    Returns:
        A ``State`` holding the agent's position, view and carried item,
        plus the adversaries and doors that were found.

    Raises:
        KeyError: If a required entry (``colAgent``, ``rowAgent`` or an
            adversary's ``row<Color>``) is missing from ``ints``.
        AssertionError: If a ``<Color>LockedDoorOpen`` label is present
            without a matching ``<Color>DoorOpen`` label; that case is not
            handled.
    """
    ints = {name: int(value) for name, value in ints.items()}

    # Map "<Owner>Carrying<Item>" labels that hold to owner -> item,
    # splitting at the first occurrence of "Carrying".
    carried_by = {}
    for formula, holds in booleans.items():
        if holds and "Carrying" in formula:
            owner, _, item = formula.partition("Carrying")
            carried_by[owner] = item

    adversaries = ()
    doors = ()
    for color_name in COLOR_NAMES:
        color = color_name.capitalize()

        if "col" + color in ints:
            adversaries += (
                AdversaryState(
                    color,
                    ints["col" + color],
                    ints["row" + color],
                    # None when the adversary's view is not tracked.
                    ints.get("view" + color),
                    carrying=carried_by.get(color, ""),
                ),
            )

        # NOTE(review): "col<Color>Box" / "col<Color>Key" entries are not
        # turned into State.boxes / State.keys. The previous code collected
        # keys into a local that was never passed to State, so the returned
        # value is unchanged by dropping that dead code.

        if color + "DoorOpen" in booleans:
            # An open door is reported as not locked, a closed one as locked.
            doors += (DoorState(color, locked=not booleans[color + "DoorOpen"]),)
        elif color + "LockedDoorOpen" in booleans:
            # Explicit raise instead of `assert False` so the check is not
            # stripped when Python runs with -O.
            raise AssertionError(
                f"{color}LockedDoorOpen labels are not handled by to_state"
            )

    # Keyword arguments are required here: State declares `carrying` before
    # `viewAgent`, so passing (col, row, view, carrying) positionally would
    # store the view in `carrying` and the carried item in `viewAgent`.
    return State(
        colAgent=ints["colAgent"],
        rowAgent=ints["rowAgent"],
        viewAgent=ints.get("viewAgent"),
        carrying=carried_by.get("Agent", ""),
        adversaries=adversaries,
        doors=doors,
    )
|