You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

364 lines
11 KiB

from __future__ import annotations
import math
from typing import Any, Callable
import numpy as np
from minigrid.core.constants import OBJECT_TO_IDX, TILE_PIXELS
from minigrid.core.world_object import Wall, WorldObj
from minigrid.utils.rendering import (
downsample,
fill_coords,
highlight_img,
point_in_rect,
point_in_triangle,
rotate_fn,
)
class Grid:
    """
    Represent a grid and operations on it.

    Cells are stored row-major in ``self.grid``: the cell at column ``i`` and
    row ``j`` lives at index ``j * width + i``. A parallel row-major list,
    ``self.background``, holds one extra value per cell. That list is only
    accessed through ``set_background`` / ``get_background``.
    """

    # Static cache of pre-rendered tiles, shared by every Grid instance and
    # keyed by everything that influences a tile's pixels (see render_tile).
    tile_cache: dict[tuple[Any, ...], Any] = {}

    def __init__(self, width: int, height: int):
        """
        Create an empty grid.

        :param width: number of columns (must be at least 3)
        :param height: number of rows (must be at least 3)
        """
        assert width >= 3
        assert height >= 3

        self.width: int = width
        self.height: int = height

        self.grid: list[WorldObj | None] = [None] * (width * height)
        self.background: list[Any] = [None] * (width * height)

    def __contains__(self, key: Any) -> bool:
        """
        Test whether an object is present in the grid.

        :param key: either a ``WorldObj`` instance, matched by identity, or a
            ``(color, type)`` tuple. A ``None`` color matches any color of the
            given type.
        :return: True if some cell matches. Keys of any other type always
            yield False.
        """
        if isinstance(key, WorldObj):
            return any(e is key for e in self.grid)
        if isinstance(key, tuple):
            for e in self.grid:
                if e is None:
                    continue
                if (e.color, e.type) == key:
                    return True
                if key[0] is None and key[1] == e.type:
                    return True
        return False

    def __eq__(self, other: object) -> bool:
        """Two grids are equal when their full encodings are identical."""
        if not isinstance(other, Grid):
            # Let Python try the reflected comparison or fall back to identity
            # instead of crashing on ``other.encode``.
            return NotImplemented
        return np.array_equal(other.encode(), self.encode())

    def __ne__(self, other: object) -> bool:
        return not self == other

    def copy(self) -> Grid:
        """Return a deep copy of this grid, including all contained objects."""
        from copy import deepcopy

        return deepcopy(self)

    def set(self, i: int, j: int, v: WorldObj | None):
        """
        Store ``v`` at column ``i``, row ``j``.

        :raises AssertionError: if the position is outside the grid
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        self.grid[j * self.width + i] = v

    def get(self, i: int, j: int) -> WorldObj | None:
        """
        Return the object at column ``i``, row ``j`` (None for an empty cell).

        :raises AssertionError: if the position is outside the grid
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        assert self.grid is not None
        return self.grid[j * self.width + i]

    def set_background(self, i: int, j: int, v: Any):
        """
        Store background value ``v`` at column ``i``, row ``j``.

        :raises AssertionError: if the position is outside the grid
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        self.background[j * self.width + i] = v

    def get_background(self, i: int, j: int) -> Any:
        """
        Return the background value at column ``i``, row ``j``.

        :raises AssertionError: if the position is outside the grid
        """
        assert (
            0 <= i < self.width
        ), f"column index {i} outside of grid of width {self.width}"
        assert (
            0 <= j < self.height
        ), f"row index {j} outside of grid of height {self.height}"
        return self.background[j * self.width + i]

    @staticmethod
    def _make_wall_obj(obj_type: Any) -> Any:
        """Call ``obj_type`` as a zero-arg factory, or use it as-is if that fails."""
        try:
            return obj_type()
        except TypeError:
            # obj_type is presumably already an object instance (not
            # callable). The same instance is then reused for every cell.
            return obj_type

    def horz_wall(
        self,
        x: int,
        y: int,
        length: int | None = None,
        obj_type: Callable[[], WorldObj] = Wall,
    ):
        """
        Fill a horizontal run of cells starting at ``(x, y)``.

        :param length: number of cells; defaults to the rest of the row
        :param obj_type: zero-arg factory producing one object per cell, or an
            object instance to place in every cell
        """
        if length is None:
            length = self.width - x
        for i in range(length):
            # Only the factory call is guarded, so a TypeError from ``set``
            # (e.g. a bad coordinate type) is never misread as "not callable".
            self.set(x + i, y, self._make_wall_obj(obj_type))

    def vert_wall(
        self,
        x: int,
        y: int,
        length: int | None = None,
        obj_type: Callable[[], WorldObj] = Wall,
    ):
        """
        Fill a vertical run of cells starting at ``(x, y)``.

        :param length: number of cells; defaults to the rest of the column
        :param obj_type: zero-arg factory producing one object per cell, or an
            object instance to place in every cell
        """
        if length is None:
            length = self.height - y
        for j in range(length):
            self.set(x, y + j, self._make_wall_obj(obj_type))

    def wall_rect(self, x: int, y: int, w: int, h: int):
        """Draw the outline of a ``w`` x ``h`` rectangle whose top-left is ``(x, y)``."""
        self.horz_wall(x, y, w)
        self.horz_wall(x, y + h - 1, w)
        self.vert_wall(x, y, h)
        self.vert_wall(x + w - 1, y, h)

    def rotate_left(self) -> Grid:
        """
        Return a new grid rotated to the left (counter-clockwise).

        NOTE(review): only ``self.grid`` is rotated. The returned grid's
        background is left as all None. Confirm this is intended.
        """
        grid = Grid(self.height, self.width)

        for i in range(self.width):
            for j in range(self.height):
                v = self.get(i, j)
                grid.set(j, grid.height - 1 - i, v)

        return grid

    def slice(self, topX: int, topY: int, width: int, height: int) -> Grid:
        """
        Return a ``width`` x ``height`` sub-grid whose top-left is ``(topX, topY)``.

        Positions falling outside this grid are filled with fresh ``Wall()``
        objects. NOTE(review): the background layer is not copied. Confirm
        this is intended.
        """
        grid = Grid(width, height)

        for j in range(height):
            for i in range(width):
                x = topX + i
                y = topY + j

                if 0 <= x < self.width and 0 <= y < self.height:
                    v = self.get(x, y)
                else:
                    v = Wall()

                grid.set(i, j, v)

        return grid

    @classmethod
    def render_tile(
        cls,
        obj: WorldObj | None,
        agent_dir: int | None = None,
        adversaries: list[Any] | None = None,
        highlight: bool = False,
        tile_size: int = TILE_PIXELS,
        subdivs: int = 3,
    ) -> np.ndarray:
        """
        Render a single tile and cache the result.

        :param obj: object occupying the tile, or None
        :param agent_dir: if not None, draw the agent triangle (red) rotated
            by ``agent_dir`` quarter turns
        :param adversaries: objects exposing ``adversary_dir``, ``color`` and
            ``rgb``. Each is drawn as a triangle rotated by its own
            ``adversary_dir`` quarter turns and filled with its ``rgb``.
        :param highlight: whether to apply ``highlight_img`` to the tile
        :param tile_size: output tile size in pixels
        :param subdivs: supersampling factor used before downsampling
        :return: the rendered ``tile_size`` x ``tile_size`` RGB image. The
            array is shared with the cache, so callers should not mutate it.
        """
        if adversaries is None:
            adversaries = []

        # Hash map lookup key for the cache. Each adversary contributes its
        # (direction, color). NOTE(review): ``rgb`` is not part of the key,
        # so it is assumed to be determined by ``color``.
        key: tuple[Any, ...] = (agent_dir, highlight, tile_size)
        for adversary in adversaries:
            key += (adversary.adversary_dir, adversary.color)
        key = obj.encode() + key if obj is not None else key

        if key in cls.tile_cache:
            return cls.tile_cache[key]

        img = np.zeros(
            shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
        )

        # Draw the grid lines (top and left edges)
        fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
        fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))

        if obj is not None:
            obj.render(img)

        # Unrotated triangle used for both the agent and the adversaries.
        base_tri = point_in_triangle(
            (0.12, 0.19),
            (0.87, 0.50),
            (0.12, 0.81),
        )

        # Overlay the agent on top
        if agent_dir is not None:
            agent_tri = rotate_fn(
                base_tri, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir
            )
            fill_coords(img, agent_tri, (255, 0, 0))

        # Overlay the adversaries. Each one is rotated from the base triangle,
        # so the agent's or an earlier adversary's rotation cannot accumulate.
        for adversary in adversaries:
            adversary_tri = rotate_fn(
                base_tri,
                cx=0.5,
                cy=0.5,
                theta=0.5 * math.pi * adversary.adversary_dir,
            )
            fill_coords(img, adversary_tri, adversary.rgb)

        # Highlight the cell if needed
        if highlight:
            highlight_img(img)

        # Downsample the image to perform supersampling/anti-aliasing
        img = downsample(img, subdivs)

        # Cache the rendered tile
        cls.tile_cache[key] = img

        return img

    def render(
        self,
        tile_size: int,
        agent_pos: tuple[int, int],
        agent_dir: int | None = None,
        adversaries: list[Any] | None = None,
        highlight_mask: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Render this grid at a given scale.

        :param tile_size: tile size in pixels
        :param agent_pos: ``(x, y)`` cell in which the agent is drawn
        :param agent_dir: agent direction in quarter turns; the agent is only
            drawn when this is not None
        :param adversaries: objects exposing ``adversary_pos``,
            ``adversary_dir``, ``color`` and ``rgb``, each drawn in the cell
            at its ``adversary_pos``
        :param highlight_mask: boolean array indexed ``[x, y]``; defaults to
            no highlighting
        :return: ``(height * tile_size, width * tile_size, 3)`` uint8 image
        """
        if highlight_mask is None:
            highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        # Compute the total grid size
        width_px = self.width * tile_size
        height_px = self.height * tile_size

        img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)

        # Adversaries not yet drawn. Once drawn they are dropped, so later
        # cells do not re-test them.
        remaining = list(adversaries) if adversaries is not None else []

        # Render the grid
        for j in range(self.height):
            for i in range(self.width):
                cell = self.get(i, j)

                agent_here = np.array_equal(agent_pos, (i, j))
                present_adversaries = [
                    adversary
                    for adversary in remaining
                    if np.array_equal(adversary.adversary_pos, (i, j))
                ]

                tile_img = Grid.render_tile(
                    cell,
                    agent_dir=agent_dir if agent_here else None,
                    adversaries=present_adversaries,
                    # Plain bool keeps the tile-cache key free of numpy scalars.
                    highlight=bool(highlight_mask[i, j]),
                    tile_size=tile_size,
                )

                ymin = j * tile_size
                ymax = (j + 1) * tile_size
                xmin = i * tile_size
                xmax = (i + 1) * tile_size
                img[ymin:ymax, xmin:xmax, :] = tile_img

                if present_adversaries:
                    remaining = [a for a in remaining if a not in present_adversaries]

        return img

    def encode(self, vis_mask: np.ndarray | None = None) -> np.ndarray:
        """
        Produce a compact numpy encoding of the grid.

        :param vis_mask: boolean array indexed ``[x, y]``; cells where it is
            False are left as all zeros. Defaults to every cell visible.
        :return: uint8 array of shape ``(width, height, 3)``. Empty visible
            cells encode as ``(OBJECT_TO_IDX["empty"], 0, 0)``; occupied cells
            use ``WorldObj.encode()``.
        """
        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)

        array = np.zeros((self.width, self.height, 3), dtype="uint8")

        for i in range(self.width):
            for j in range(self.height):
                if vis_mask[i, j]:
                    v = self.get(i, j)

                    if v is None:
                        array[i, j, 0] = OBJECT_TO_IDX["empty"]
                        array[i, j, 1] = 0
                        array[i, j, 2] = 0

                    else:
                        array[i, j, :] = v.encode()

        return array

    @staticmethod
    def decode(array: np.ndarray) -> tuple[Grid, np.ndarray]:
        """
        Decode an array grid encoding back into a grid.

        :param array: array of shape ``(width, height, 3)``
        :return: the decoded grid and a boolean visibility mask indexed
            ``[x, y]`` that is False where the type index is "unseen"
        """
        width, height, channels = array.shape
        assert channels == 3

        vis_mask = np.ones(shape=(width, height), dtype=bool)

        grid = Grid(width, height)
        for i in range(width):
            for j in range(height):
                type_idx, color_idx, state = array[i, j]
                v = WorldObj.decode(type_idx, color_idx, state)
                grid.set(i, j, v)
                vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]

        return grid, vis_mask

    def process_vis(self, agent_pos: tuple[int, int]) -> np.ndarray:
        """
        Compute cell visibility from ``agent_pos`` and blank out hidden cells.

        Visibility starts at the agent's cell. It is propagated row by row from
        the last row toward row 0: sideways to the right and left, and
        diagonally and straight up to the previous row. Propagation stops at
        any cell whose ``see_behind()`` is False. Every cell left invisible is
        then set to None in this grid (the grid is mutated in place).

        NOTE(review): propagation never moves to higher row indices, so this
        presumably expects the agent to sit on the bottom row of the view.

        :return: boolean visibility mask indexed ``[x, y]``
        """
        mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        mask[agent_pos[0], agent_pos[1]] = True

        for j in reversed(range(0, self.height)):
            # Left-to-right sweep: spread visibility right and up-right.
            for i in range(0, self.width - 1):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i + 1, j] = True
                if j > 0:
                    mask[i + 1, j - 1] = True
                    mask[i, j - 1] = True

            # Right-to-left sweep: spread visibility left and up-left.
            for i in reversed(range(1, self.width)):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i - 1, j] = True
                if j > 0:
                    mask[i - 1, j - 1] = True
                    mask[i, j - 1] = True

        for j in range(0, self.height):
            for i in range(0, self.width):
                if not mask[i, j]:
                    self.set(i, j, None)

        return mask