You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.3 KiB
81 lines
2.3 KiB
from minigrid.core.world_object import WorldObj
|
|
from minigrid.core.tasks import DoRandom, TaskManager
|
|
import numpy as np
|
|
import math
|
|
|
|
from minigrid.core.state import AdversaryState
|
|
|
|
|
|
from minigrid.core.constants import (
|
|
COLOR_TO_IDX,
|
|
COLORS,
|
|
IDX_TO_COLOR,
|
|
IDX_TO_OBJECT,
|
|
OBJECT_TO_IDX,
|
|
DIR_TO_VEC
|
|
)
|
|
from minigrid.utils.rendering import (
|
|
fill_coords,
|
|
point_in_triangle,
|
|
rotate_fn,
|
|
)
|
|
|
|
# Lookup from an adversary direction index (0-3) to the value emitted as the
# third element of the tuple returned by ``Adversary.encode()``.
# Resulting mapping: 0 -> 3, 1 -> 4, 2 -> 5, 3 -> 6.
VIEW_TO_STATE_IDX = {view: view + 3 for view in range(4)}
|
|
|
|
|
|
class Adversary(WorldObj):
    """World object of type ``"adversary"`` driven by a ``TaskManager``.

    An adversary keeps its own position and facing direction, may carry
    another object, and asks its task manager for the next action to take.
    """

    def __init__(self, adversary_pos, adversary_dir=1, color="blue", tasks=None, repeating=False):
        """Create an adversary.

        Args:
            adversary_pos: Position of the adversary; indexed as
                ``[0]`` (column) and ``[1]`` (row) by ``to_state``.
            adversary_dir: Facing direction index; ``dir_vec`` and ``encode``
                expect a value in ``0..3``.
            color: Color name; must be a key of ``COLORS``.
            tasks: List of tasks handed to the ``TaskManager``. When omitted
                (``None``), a fresh ``[DoRandom()]`` list is created for this
                instance.
            repeating: Forwarded to ``TaskManager`` as ``repeating``.
        """
        super().__init__("adversary", color)
        if tasks is None:
            # Build the default per instance. A list literal in the signature
            # is evaluated once and would be shared by every adversary, so
            # append_task()/insert_task() on one could affect the others.
            tasks = [DoRandom()]
        self.adversary_pos = adversary_pos  # TODO
        self.adversary_dir = adversary_dir  # TODO
        self.color = color
        self.rgb = COLORS[self.color]
        self.task_manager = TaskManager(tasks, repeating=repeating)
        self.carrying = None
        self.name = color.capitalize()

    def render(self, img):
        """Draw the adversary into ``img`` as a triangle in its color."""
        tri_fn = point_in_triangle(
            (0.12, 0.19),
            (0.87, 0.50),
            (0.12, 0.81),
        )

        # Rotate the triangle by one quarter turn per direction step.
        tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * self.adversary_dir)
        fill_coords(img, tri_fn, COLORS[self.color])

    def dir_vec(self):
        """Return the ``DIR_TO_VEC`` entry for the current direction.

        Asserts that ``adversary_dir`` lies in ``0..3`` before the lookup.
        """
        assert 0 <= self.adversary_dir < 4
        return DIR_TO_VEC[self.adversary_dir]

    def can_overlap(self):
        """Always return ``True``."""
        return True

    def encode(self):
        """Return a 3-tuple describing this object.

        The tuple is ``(object type index, color index, direction state
        index)``, looked up in ``OBJECT_TO_IDX``, ``COLOR_TO_IDX`` and
        ``VIEW_TO_STATE_IDX`` respectively.
        """
        return (
            OBJECT_TO_IDX[self.type],
            COLOR_TO_IDX[self.color],
            VIEW_TO_STATE_IDX[self.adversary_dir],
        )

    def append_task(self, task):
        """Add ``task`` to the end of the task manager's task list."""
        self.task_manager.tasks.append(task)

    def insert_task(self, task, position):
        """Insert ``task`` at index ``position`` of the task manager's task list."""
        self.task_manager.tasks.insert(position, task)

    def to_state(self):
        """Build an ``AdversaryState`` snapshot of this adversary.

        ``carrying`` is the capitalized color followed by the capitalized
        type of the carried object, or an empty string when nothing is
        carried.
        """
        carrying = ""
        if self.carrying:
            carrying = f"{self.carrying.color.capitalize()}{self.carrying.type.capitalize()}"
        return AdversaryState(
            color=self.color.capitalize(),
            col=self.adversary_pos[0],
            row=self.adversary_pos[1],
            view=self.adversary_dir,
            carrying=carrying,
        )

    def get_action(self, env):
        """Return the task manager's best action for the current situation in ``env``."""
        return self.task_manager.get_best_action(self.adversary_pos, self.adversary_dir, self.carrying, env)

    def __str__(self):
        """Return ``"(<position>, <direction>)"``."""
        return f"({self.adversary_pos}, {self.adversary_dir})"
|