You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
364 lines
11 KiB
364 lines
11 KiB
from __future__ import annotations
|
|
|
|
import math
|
|
from typing import Any, Callable
|
|
|
|
import numpy as np
|
|
|
|
from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
|
|
from minigrid.core.world_object import Wall, WorldObj
|
|
from minigrid.utils.rendering import (
|
|
downsample,
|
|
fill_coords,
|
|
highlight_img,
|
|
point_in_rect,
|
|
point_in_triangle,
|
|
rotate_fn,
|
|
)
|
|
|
|
|
|
class Grid:
    """
    Represent a grid and operations on it.

    Cells are stored row-major in flat lists: cell (i, j) -- column i,
    row j -- lives at index ``j * width + i``.
    """

    # Static cache of pre-rendered tiles, shared by every Grid and keyed by
    # everything that influences a tile's pixels (see ``render_tile``).
    tile_cache: dict[tuple[Any, ...], Any] = {}

    def __init__(self, width: int, height: int):
        assert width >= 3
        assert height >= 3

        self.width: int = width
        self.height: int = height

        # Foreground object per cell (None means empty).
        self.grid: list[WorldObj | None] = [None] * (width * height)
        # Per-cell background value, same row-major layout as ``grid``.
        self.background: list[Any] = [None] * (width * height)

    def __contains__(self, key: Any) -> bool:
        """
        Check whether the grid holds a matching object.

        :param key: a WorldObj (matched by identity) or a ``(color, type)``
            tuple; a ``None`` color matches objects of that type in any color
        :return: True if some cell matches; False otherwise, including for
            keys of any other type
        """
        if isinstance(key, WorldObj):
            return any(e is key for e in self.grid)

        if isinstance(key, tuple):
            for e in self.grid:
                if e is None:
                    continue
                if (e.color, e.type) == key:
                    return True
                if key[0] is None and key[1] == e.type:
                    return True

        return False

    def __eq__(self, other: object) -> bool:
        """
        Two grids are equal when their full ``encode()`` arrays are identical.

        Non-Grid operands yield NotImplemented so Python falls back to its
        default comparison instead of raising AttributeError.
        """
        if not isinstance(other, Grid):
            return NotImplemented
        return np.array_equal(self.encode(), other.encode())

    def __ne__(self, other: object) -> bool:
        return not self == other

    def copy(self) -> Grid:
        """
        Return a deep copy of this grid, including the objects it holds.
        """
        from copy import deepcopy

        return deepcopy(self)

    def set(self, i: int, j: int, v: WorldObj | None):
        """
        Place ``v`` (or None to clear the cell) at column ``i``, row ``j``.
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        self.grid[j * self.width + i] = v

    def get(self, i: int, j: int) -> WorldObj | None:
        """
        Return the object at column ``i``, row ``j`` (None if empty).
        """
        assert 0 <= i < self.width
        assert 0 <= j < self.height
        assert self.grid is not None
        return self.grid[j * self.width + i]

    def set_background(self, i: int, j: int, v: Any):
        """
        Store background value ``v`` for column ``i``, row ``j``.
        """
        assert 0 <= i < self.width
        assert 0 <= j < self.height
        self.background[j * self.width + i] = v

    def get_background(self, i: int, j: int) -> Any:
        """
        Return the background value stored for column ``i``, row ``j``.
        """
        assert 0 <= i < self.width
        assert 0 <= j < self.height
        return self.background[j * self.width + i]

    @staticmethod
    def _make_obj(obj_type: Callable[[], WorldObj] | WorldObj) -> WorldObj:
        """
        Build the object placed in one cell by a wall-drawing call.

        A callable (typically a WorldObj class) is called to get a fresh
        object per cell. Anything else is used as-is, so a pre-built instance
        is shared by every cell it is placed in. Errors raised while
        constructing propagate instead of being mistaken for "obj_type is an
        instance".
        """
        return obj_type() if callable(obj_type) else obj_type

    def horz_wall(
        self,
        x: int,
        y: int,
        length: int | None = None,
        obj_type: Callable[[], WorldObj] | WorldObj = Wall,
    ):
        """
        Fill a horizontal run of cells on row ``y``, starting at column ``x``.

        :param length: number of cells; defaults to the rest of the row
        :param obj_type: factory (e.g. a WorldObj class) called once per cell,
            or an object instance placed in every cell
        """
        if length is None:
            length = self.width - x
        for i in range(length):
            self.set(x + i, y, self._make_obj(obj_type))

    def vert_wall(
        self,
        x: int,
        y: int,
        length: int | None = None,
        obj_type: Callable[[], WorldObj] | WorldObj = Wall,
    ):
        """
        Fill a vertical run of cells in column ``x``, starting at row ``y``.

        :param length: number of cells; defaults to the rest of the column
        :param obj_type: factory (e.g. a WorldObj class) called once per cell,
            or an object instance placed in every cell
        """
        if length is None:
            length = self.height - y
        for j in range(length):
            self.set(x, y + j, self._make_obj(obj_type))

    def wall_rect(self, x: int, y: int, w: int, h: int):
        """
        Draw the outline of a ``w`` x ``h`` rectangle of walls whose
        top-left cell is (x, y).
        """
        self.horz_wall(x, y, w)
        self.horz_wall(x, y + h - 1, w)
        self.vert_wall(x, y, h)
        self.vert_wall(x + w - 1, y, h)

    def rotate_left(self) -> Grid:
        """
        Return a new grid rotated to the left (counter-clockwise).

        Cell (i, j) moves to (j, width - 1 - i), so the result is
        ``height`` wide and ``width`` high. Objects are shared with this
        grid, not copied.
        """
        grid = Grid(self.height, self.width)

        for i in range(self.width):
            for j in range(self.height):
                grid.set(j, grid.height - 1 - i, self.get(i, j))

        return grid

    def slice(self, topX: int, topY: int, width: int, height: int) -> Grid:
        """
        Get a ``width`` x ``height`` sub-grid whose top-left cell is
        (topX, topY).

        Cells falling outside this grid are filled with fresh Wall objects.
        In-bounds objects are shared with this grid, not copied.
        """
        grid = Grid(width, height)

        for j in range(height):
            for i in range(width):
                x = topX + i
                y = topY + j

                if 0 <= x < self.width and 0 <= y < self.height:
                    v = self.get(x, y)
                else:
                    v = Wall()

                grid.set(i, j, v)

        return grid

    @classmethod
    def render_tile(
        cls,
        obj: WorldObj | None,
        agent_dir: int | None = None,
        adversaries: list | None = None,
        highlight: bool = False,
        tile_size: int = TILE_PIXELS,
        subdivs: int = 3,
    ) -> np.ndarray:
        """
        Render a single tile and cache the result.

        :param obj: object occupying the cell, or None for an empty cell
        :param agent_dir: direction of the agent on this cell, or None if the
            agent is not here; the agent is drawn as a red triangle
        :param adversaries: adversaries on this cell, each exposing
            ``adversary_dir``, ``color`` and ``rgb``; drawn after the agent,
            in list order
        :param highlight: whether to highlight the tile
        :param tile_size: output tile size in pixels
        :param subdivs: supersampling factor used for anti-aliasing
        :return: uint8 image of shape (tile_size, tile_size, 3)
        """
        if adversaries is None:
            adversaries = []

        # Cache key: every input that changes the rendered pixels, including
        # the supersampling factor. Adversaries are keyed by direction and
        # color in drawing order.
        # NOTE(review): assumes an adversary's ``rgb`` is determined by its
        # ``color`` -- confirm, otherwise stale colors may be served.
        key: tuple[Any, ...] = (agent_dir, highlight, tile_size, subdivs)
        for adversary in adversaries:
            key += (adversary.adversary_dir, adversary.color)
        # Use the same None test as the drawing code below, so any drawn
        # object always contributes to the key.
        if obj is not None:
            key = obj.encode() + key

        if key in cls.tile_cache:
            return cls.tile_cache[key]

        img = np.zeros(
            shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
        )

        # Draw the grid lines (top and left edges)
        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

        if obj is not None:
            obj.render(img)

        # Unrotated triangle with its apex toward +x. Each agent/adversary
        # rotates THIS base shape by its own direction; rotating an already
        # rotated shape would accumulate angles when sprites share a tile.
        base_tri = point_in_triangle(
            (0.12, 0.19),
            (0.87, 0.50),
            (0.12, 0.81),
        )

        # Overlay the agent on top
        if agent_dir is not None:
            tri_fn = rotate_fn(
                base_tri, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir
            )
            fill_coords(img, tri_fn, (255, 0, 0))

        for adversary in adversaries:
            tri_fn = rotate_fn(
                base_tri,
                cx=0.5,
                cy=0.5,
                theta=0.5 * math.pi * adversary.adversary_dir,
            )
            fill_coords(img, tri_fn, adversary.rgb)

        # Highlight the cell if needed
        if highlight:
            highlight_img(img)

        # Downsample the image to perform supersampling/anti-aliasing
        img = downsample(img, subdivs)

        # Cache the rendered tile
        cls.tile_cache[key] = img

        return img

    def render(
        self,
        tile_size: int,
        agent_pos: tuple[int, int],
        agent_dir: int | None = None,
        adversaries: list | None = None,
        highlight_mask: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Render this grid at a given scale.

        :param tile_size: tile size in pixels
        :param agent_pos: (x, y) cell of the agent
        :param agent_dir: agent direction; drawn only on the agent's cell
        :param adversaries: objects exposing ``adversary_pos``,
            ``adversary_dir``, ``color`` and ``rgb``; each is drawn on the
            cell equal to its ``adversary_pos``
        :param highlight_mask: (width, height) boolean mask of cells to
            highlight; defaults to no highlighting
        :return: uint8 image of shape
            (height * tile_size, width * tile_size, 3)
        """
        if highlight_mask is None:
            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # Adversaries not drawn yet. A position equals exactly one (i, j), so
        # an adversary can be dropped once placed to shorten later scans.
        remaining = [] if adversaries is None else list(adversaries)

        # Compute the total grid size
        width_px = self.width * tile_size
        height_px = self.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Render the grid
        for j in range(self.height):
            for i in range(self.width):
                cell = self.get(i, j)

                agent_here = np.array_equal(agent_pos, (i, j))
                present = [
                    a for a in remaining if np.array_equal(a.adversary_pos, (i, j))
                ]

                tile_img = Grid.render_tile(
                    cell,
                    agent_dir=agent_dir if agent_here else None,
                    adversaries=present,
                    highlight=highlight_mask[i, j],
                    tile_size=tile_size,
                )

                ymin = j * tile_size
                ymax = (j + 1) * tile_size
                xmin = i * tile_size
                xmax = (i + 1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

                if present:
                    # Filter by identity: equality (``in``) could also drop a
                    # distinct, equal-comparing adversary before it is drawn.
                    placed = {id(a) for a in present}
                    remaining = [a for a in remaining if id(a) not in placed]

        return img

    def encode(self, vis_mask: np.ndarray | None = None) -> np.ndarray:
        """
        Produce a compact numpy encoding of the grid.

        :param vis_mask: (width, height) boolean mask of cells to encode;
            defaults to all cells. Masked-out cells are left as all zeros.
        :return: uint8 array of shape (width, height, 3) holding each cell's
            ``encode()`` triple; empty cells encode as
            ``(OBJECT_TO_IDX["empty"], 0, 0)``
        """
        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)
        assert vis_mask is not None

        array = np.zeros((self.width, self.height, 3), dtype="uint8")

        for i in range(self.width):
            for j in range(self.height):
                if not vis_mask[i, j]:
                    continue

                v = self.get(i, j)
                if v is None:
                    array[i, j, 0] = OBJECT_TO_IDX["empty"]
                    array[i, j, 1] = 0
                    array[i, j, 2] = 0
                else:
                    array[i, j, :] = v.encode()

        return array

    @staticmethod
    def decode(array: np.ndarray) -> tuple[Grid, np.ndarray]:
        """
        Decode an array grid encoding back into a grid.

        :param array: array of shape (width, height, 3), as produced by
            ``encode``
        :return: the decoded grid and a (width, height) boolean mask that is
            False wherever the type index equals ``OBJECT_TO_IDX["unseen"]``
        """
        width, height, channels = array.shape
        assert channels == 3

        vis_mask = np.ones(shape=(width, height), dtype=bool)

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                type_idx, color_idx, state = array[i, j]
                v = WorldObj.decode(type_idx, color_idx, state)
                grid.set(i, j, v)
                vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]

        return grid, vis_mask

    def process_vis(self, agent_pos: tuple[int, int]) -> np.ndarray:
        """
        Compute which cells are visible from ``agent_pos`` and blank the rest.

        Visibility starts at the agent's cell and is swept row by row from
        the last row (j = height - 1) toward row 0. A visible cell whose
        object does not block sight (empty, or ``see_behind()`` is true)
        makes its horizontal neighbour visible and, when a row j - 1 exists,
        the cells above it too. Each row is swept left-to-right, then
        right-to-left.

        Cells left invisible are cleared to None IN PLACE.

        :param agent_pos: (x, y) cell the view originates from
        :return: (width, height) boolean visibility mask
        """
        mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        mask[agent_pos[0], agent_pos[1]] = True

        for j in reversed(range(self.height)):
            # Left-to-right sweep
            for i in range(self.width - 1):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i + 1, j] = True
                if j > 0:
                    mask[i + 1, j - 1] = True
                    mask[i, j - 1] = True

            # Right-to-left sweep
            for i in reversed(range(1, self.width)):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i - 1, j] = True
                if j > 0:
                    mask[i - 1, j - 1] = True
                    mask[i, j - 1] = True

        for j in range(self.height):
            for i in range(self.width):
                if not mask[i, j]:
                    self.set(i, j, None)

        return mask
|