You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

163 lines
6.3 KiB

from dataclasses import dataclass, field
import itertools
from typing import Optional
from minigrid.core.constants import (
COLOR_NAMES,
IDX_TO_COLOR
)
@dataclass(frozen=True, eq=True)
class KeyState:
    """Immutable, hashable record describing one key by colour and cell."""

    # Colour name of the key; empty string when unset.
    color: str = ""
    # Column index of the key.
    col: int = 0
    # Row index of the key.
    row: int = 0
@dataclass(frozen=True, eq=True)
class BallState:
    """Immutable, hashable record describing one ball by colour and cell."""

    # Colour name of the ball; empty string when unset.
    color: str = ""
    # Column index of the ball.
    col: int = 0
    # Row index of the ball.
    row: int = 0
@dataclass(frozen=True, eq=True)
class BoxState:
    """Immutable, hashable record describing one box by colour and cell."""

    # Colour name of the box; empty string when unset.
    color: str = ""
    # Column index of the box.
    col: int = 0
    # Row index of the box.
    row: int = 0
@dataclass(frozen=True, eq=True)
class DoorState:
    """Immutable, hashable record describing one door by colour and lock flag."""

    # Colour name of the door; empty string when unset.
    color: str = ""
    # Lock flag; a door is considered locked unless stated otherwise.
    locked: bool = True
@dataclass(frozen=True, eq=True)
class AdversaryState:
    """Immutable, hashable record describing one adversary."""

    # Colour name identifying the adversary; empty string when unset.
    color: str = ""
    # Column index of the adversary.
    col: int = 0
    # Row index of the adversary.
    row: int = 0
    # Viewing direction, or None when no direction is tracked.
    view: Optional[int] = None
    # Description of what the adversary carries; empty string for nothing.
    carrying: str = ""

    def csv(self, delim=",") -> str:
        """Return ``col``, ``row`` and (when set) ``view`` joined by *delim*."""
        position = f"{self.col}{delim}{self.row}"
        if self.view is None:
            return position
        return f"{position}{delim}{self.view}"
@dataclass(frozen=True, eq=True)
class State:
    """Immutable, hashable snapshot of the agent and the objects around it.

    Only the agent's position is mandatory. The tuples hold per-object state
    records; ``adversaries`` entries must expose ``col``, ``row``, ``view``
    and ``color`` attributes, which the methods below read.

    Direction codes, as used throughout :meth:`adversary_distance`:
    0 = east, 1 = south, 2 = west, 3 = north.
    """

    colAgent: int
    rowAgent: int
    carrying: Optional[str] = None
    viewAgent: Optional[int] = None
    adversaries: tuple = field(default_factory=tuple)
    balls: tuple = field(default_factory=tuple)
    boxes: tuple = field(default_factory=tuple)
    keys: tuple = field(default_factory=tuple)
    doors: tuple = field(default_factory=tuple)
    lockeddoors: tuple = field(default_factory=tuple)

    def _view(self) -> bool:
        """Return True when the agent's viewing direction is tracked."""
        return self.viewAgent is not None

    def csv(self, delim=",") -> str:
        """Return the agent and adversary values joined by *delim*.

        The agent contributes column, row and (only when tracked) view. Every
        adversary always contributes column, row and view, so an adversary
        without a view is rendered as the text ``None``; this keeps the
        columns aligned with :attr:`feature_space`.
        """
        cells = [f"{self.colAgent}", f"{self.rowAgent}"]
        if self._view():
            cells.append(f"{self.viewAgent}")
        for adv in self.adversaries:
            cells.extend((f"{adv.col}", f"{adv.row}", f"{adv.view}"))
        return delim.join(cells)

    def tuple(self) -> tuple:
        """Return the agent and adversary values as one flat tuple.

        Unlike :meth:`csv`, adversary views are included only when the
        agent's own view is tracked.
        """
        with_view = self._view()
        values = [self.colAgent, self.rowAgent]
        if with_view:
            values.append(self.viewAgent)
        for adv in self.adversaries:
            values.extend((adv.col, adv.row))
            if with_view:
                values.append(adv.view)
        # Inside a method body the bare name resolves to the builtin, not to
        # this method (class scope is skipped during name lookup).
        return tuple(values)

    def adversary_distance(self, adversary) -> int:
        """Return the Manhattan distance to *adversary* plus turn penalties.

        Args:
            adversary: object exposing ``col``, ``row`` and ``view``.

        Returns:
            ``|col difference| + |row difference|`` increased by 1 or 2
            depending on the viewing direction, as detailed inline.
        """
        col_delta = self.colAgent - adversary.col
        row_delta = self.rowAgent - adversary.row
        distance = abs(col_delta) + abs(row_delta)

        # NOTE(review): the original named the column difference "vertical"
        # and the row difference "horizontal"; whether the penalties are
        # attached to the intended axis should be confirmed.
        if row_delta == 0:
            # Same row: penalise according to the adversary's view.
            if adversary.view in (0, 2):  # looks east or west
                distance += 1
            elif col_delta > 0 and adversary.view == 1:  # looks south
                distance += 2
            elif col_delta < 0 and adversary.view == 3:  # looks north
                distance += 2
        elif col_delta == 0:
            # Same column: penalise according to the adversary's view.
            if adversary.view in (1, 3):  # looks south or north
                distance += 1
            elif row_delta > 0 and adversary.view == 2:  # looks west
                distance += 2
            elif row_delta < 0 and adversary.view == 0:  # looks east
                distance += 2
        else:
            # Differs on both axes.
            # NOTE(review): this branch reads the agent's view while the two
            # branches above read the adversary's view; confirm which is meant.
            if row_delta > 0 and self.viewAgent == 2:  # looks west
                distance += 1
            # East is direction 0; the previous check against 1 collided with
            # the "looks south" case below and never matched an east view.
            if row_delta < 0 and self.viewAgent == 0:  # looks east
                distance += 1
            if col_delta < 0 and self.viewAgent == 3:  # looks north
                distance += 1
            if col_delta > 0 and self.viewAgent == 1:  # looks south
                distance += 1
        return distance

    @property
    def feature_space(self) -> list[str]:
        """Return the column names matching the values produced by :meth:`csv`."""
        names = ["colAgent", "rowAgent"]
        if self._view():
            names.append("viewAgent")
        for adv in self.adversaries:
            suffix = adv.color.capitalize()
            names.extend((f"col{suffix}", f"row{suffix}", f"view{suffix}"))
        return names
def to_state(ints, booleans):
    """Build a :class:`State` from integer and boolean variable valuations.

    Args:
        ints: mapping from variable name (``colAgent``, ``rowAgent``,
            optionally ``viewAgent``, and per capitalised colour
            ``col<Color>``/``row<Color>``/``view<Color>``,
            ``col<Color>Key``/``row<Color>Key``) to a value accepted by
            ``int()``.
        booleans: mapping from label to truth value. A true label of the
            form ``<Carrier>Carrying<Item>`` records that ``<Carrier>``
            carries ``<Item>``; ``<Color>DoorOpen`` describes a door.

    Returns:
        A ``State`` holding the agent, the adversaries and the doors.

    Raises:
        KeyError: if a required coordinate is missing from *ints*.
        AssertionError: if a ``<Color>LockedDoorOpen`` label is present;
            locked doors are not handled.
    """
    ints = {key: int(value) for key, value in ints.items()}

    # carrier name -> carried item, taken from every true "...Carrying..." label.
    any_carrying = dict()
    for formula, value in booleans.items():
        if not value:
            continue
        if "Carrying" in formula:
            carrier, _, item = formula.partition("Carrying")
            any_carrying[carrier] = item

    # Keyword arguments are required here: State declares `carrying` before
    # `viewAgent`, so passing (col, row, view, carrying) positionally stored
    # the view in `carrying` and the carried item in `viewAgent`.
    agent_fields = {
        "colAgent": ints["colAgent"],
        "rowAgent": ints["rowAgent"],
        "carrying": any_carrying.get("Agent", ""),
    }
    if "viewAgent" in ints:
        agent_fields["viewAgent"] = ints["viewAgent"]

    adversaries = ()
    keys = ()
    doors = ()
    for color in COLOR_NAMES:
        color = color.capitalize()

        if "col" + color in ints:
            adversaries += (
                AdversaryState(
                    color,
                    ints["col" + color],
                    ints["row" + color],
                    # None when no view is given, same as the field default.
                    ints.get("view" + color),
                    carrying=any_carrying.get(color, ""),
                ),
            )

        # TODO: "col<Color>Box" entries are recognised but not converted yet.

        if "col" + color + "Key" in ints:
            identifier = color + "Key"
            # Keys belong in `keys`; they used to be appended to `balls`.
            keys += (KeyState(color, ints["col" + identifier], ints["row" + identifier]),)

        door_label = color + "DoorOpen"
        if door_label in booleans:
            doors += (DoorState(color, locked=not booleans[door_label]),)
        elif color + "LockedDoorOpen" in booleans:
            # Raised explicitly so the check also fires under `python -O`.
            raise AssertionError(f"locked doors are not supported: {color}LockedDoorOpen")

    # NOTE(review): `keys` is collected but, as before, not forwarded to the
    # State; confirm whether `keys=keys` should be passed here.
    return State(**agent_fields, adversaries=adversaries, doors=doors)