The source code and Dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

1112 lines
36 KiB

4 weeks ago
  1. from __future__ import annotations
  2. import hashlib
  3. import math
  4. from abc import abstractmethod
  5. from typing import Any, Iterable, SupportsFloat, TypeVar
  6. import numpy
  7. import gymnasium as gym
  8. import numpy as np
  9. import pygame
  10. import pygame.freetype
  11. from gymnasium import spaces
  12. from gymnasium.core import ActType, ObsType
  13. from minigrid.core.actions import Actions
  14. from minigrid.core.constants import COLOR_NAMES, DIR_TO_VEC, TILE_PIXELS, OBJECT_TO_STR
  15. from minigrid.core.grid import Grid
  16. from minigrid.core.mission import MissionSpace
  17. from minigrid.core.world_object import Point, WorldObj, Slippery, SlipperyEast, SlipperyNorth, SlipperySouth, SlipperyWest, Lava, SlipperyNorthWest, SlipperyNorthEast, SlipperySouthWest, SlipperySouthEast
  18. from minigrid.core.adversary import Adversary
  19. from minigrid.core.tasks import DoRandom, Task, List
  20. from minigrid.core.state import State
  21. from collections import deque
  22. T = TypeVar("T")
  23. stay_at_pos_distribution = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
  24. def is_slippery(cell : WorldObj):
  25. return isinstance(cell, (SlipperySouth, Slippery, SlipperyEast, SlipperyWest, SlipperyNorth, SlipperyNorthWest, SlipperySouthEast, SlipperyNorthEast, SlipperySouthWest))
class MiniGridEnv(gym.Env):
    """
    2D grid world game environment

    A MiniGrid-style environment extended with adversaries (``add_adversary``),
    slippery tiles (``add_slippery_tile``), background tiles, an optional
    per-cell BFS shaping reward (``bfs_reward``) and a text export of the
    grid (``printGrid`` / ``export_grid``).
    """

    # Render modes understood by ``render``; "render_fps" caps the pygame
    # clock in "human" mode.
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 10,
    }
  34. def __init__(
  35. self,
  36. mission_space: MissionSpace,
  37. grid_size: int | None = None,
  38. width: int | None = None,
  39. height: int | None = None,
  40. max_steps: int = 100,
  41. see_through_walls: bool = False,
  42. agent_view_size: int = 7,
  43. render_mode: str | None = None,
  44. screen_size: int | None = 640,
  45. highlight: bool = False,
  46. tile_size: int = TILE_PIXELS,
  47. agent_pov: bool = False,
  48. **kwargs
  49. ):
  50. # Initialize mission
  51. self.mission = mission_space.sample()
  52. # Can't set both grid_size and width/height
  53. if grid_size:
  54. assert width is None and height is None
  55. width = grid_size
  56. height = grid_size
  57. assert width is not None and height is not None
  58. # Action enumeration for this environment
  59. self.actions = Actions
  60. # Actions are discrete integer values
  61. self.action_space = spaces.Discrete(len(self.actions))
  62. # Number of cells (width and height) in the agent view
  63. assert agent_view_size % 2 == 1
  64. assert agent_view_size >= 3
  65. self.agent_view_size = agent_view_size
  66. # Observations are dictionaries containing an
  67. # encoding of the grid and a textual 'mission' string
  68. image_observation_space = spaces.Box(
  69. low=0,
  70. high=255,
  71. shape=(self.agent_view_size, self.agent_view_size, 3),
  72. dtype="uint8",
  73. )
  74. self.observation_space = spaces.Dict(
  75. {
  76. "image": image_observation_space,
  77. "direction": spaces.Discrete(4),
  78. "mission": mission_space,
  79. }
  80. )
  81. # Range of possible rewards
  82. self.reward_range = (0, 1)
  83. self.screen_size = screen_size
  84. self.render_size = None
  85. self.window = None
  86. self.clock = None
  87. # Environment configuration
  88. self.width = width
  89. self.height = height
  90. assert isinstance(
  91. max_steps, int
  92. ), f"The argument max_steps must be an integer, got: {type(max_steps)}"
  93. self.max_steps = max_steps
  94. self.see_through_walls = see_through_walls
  95. # Current position and direction of the agent
  96. self.agent_pos: np.ndarray | tuple[int, int] = None
  97. self.agent_dir: int = None
  98. # Current grid and mission and carrying
  99. self.grid = Grid(width, height)
  100. self.carrying = None
  101. self.objects = list()
  102. self.doors = list()
  103. # dict of adversaries
  104. self.adversaries = dict()
  105. # Rendering attributes
  106. self.render_mode = render_mode
  107. self.highlight = highlight
  108. self.tile_size = tile_size
  109. self.agent_pov = agent_pov
  110. # Custom
  111. self.background_tiles = dict()
  112. def reset(
  113. self,
  114. *,
  115. seed: int | None = None,
  116. options: dict[str, Any] | None = None,
  117. ) -> tuple[ObsType, dict[str, Any]]:
  118. super().reset(seed=seed)
  119. # Reinitialize episode-specific variables
  120. self.agent_pos = (-1, -1)
  121. self.agent_dir = -1
  122. self.goal_pos = (-1, -1)
  123. # Generate a new random grid at the start of each episode
  124. self.objects.clear()
  125. self.doors.clear()
  126. self._gen_grid(self.width, self.height)
  127. # These fields should be defined by _gen_grid
  128. assert (
  129. self.agent_pos >= (0, 0)
  130. if isinstance(self.agent_pos, tuple)
  131. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  132. )
  133. # Check that the agent doesn't overlap with an object
  134. start_cell = self.grid.get(*self.agent_pos)
  135. assert start_cell is None or start_cell.can_overlap()
  136. # Item picked up, being carried, initially nothing
  137. self.carrying = None
  138. # Step count since episode start
  139. self.step_count = 0
  140. if self.render_mode == "human":
  141. self.render()
  142. # Return first observation
  143. obs = self.gen_obs()
  144. return obs, {}
  145. def hash(self, size=16):
  146. """Compute a hash that uniquely identifies the current state of the environment.
  147. :param size: Size of the hashing
  148. """
  149. sample_hash = hashlib.sha256()
  150. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  151. to_encode += [(adv.adversary_pos, adv.adversary_dir, adv.color) for adv in self.adversaries]
  152. for item in to_encode:
  153. sample_hash.update(str(item).encode("utf8"))
  154. return sample_hash.hexdigest()[:size]
  155. def add_adversary(
  156. self,
  157. i: int,
  158. j: int,
  159. color: str,
  160. direction: int = 0,
  161. tasks: List[Task] = [DoRandom()],
  162. repeating=False
  163. ):
  164. """
  165. Adds an adversary to the grid
  166. """
  167. adv = Adversary((i,j), direction, color, tasks=tasks, repeating=repeating)
  168. self.adversaries[color] = adv
  169. return adv
  170. @property
  171. def steps_remaining(self):
  172. return self.max_steps - self.step_count
  173. def pprint_grid(self):
  174. """
  175. Produce a pretty string of the environment's grid along with the agent.
  176. A grid cell is represented by 2-character string, the first one for
  177. the object and the second one for the color.
  178. """
  179. if self.agent_pos is None or self.agent_dir is None or self.grid is None:
  180. raise ValueError(
  181. "The environment hasn't been `reset` therefore the `agent_pos`, `agent_dir` or `grid` are unknown."
  182. )
  183. # Map of object types to short string
  184. # Map agent's direction to short string
  185. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  186. output = ""
  187. for j in range(self.grid.height):
  188. for i in range(self.grid.width):
  189. if i == self.agent_pos[0] and j == self.agent_pos[1]:
  190. output += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  191. continue
  192. tile = self.grid.get(i, j)
  193. if tile is None:
  194. output += " "
  195. continue
  196. if tile.type == "door":
  197. if tile.is_open:
  198. output += "__"
  199. elif tile.is_locked:
  200. output += "L" + tile.color[0].upper()
  201. else:
  202. output += "D" + tile.color[0].upper()
  203. continue
  204. output += OBJECT_TO_STR[tile.type] + tile.color[0].upper()
  205. if j < self.grid.height - 1:
  206. output += "\n"
  207. return output
  208. def printGrid(self, init=False):
  209. """
  210. Produce a pretty string of the environment's grid along with the agent.
  211. A grid cell is represented by 2-character string, the first one for
  212. the object and the second one for the color.
  213. """
  214. if init:
  215. self._gen_grid(self.grid.width, self.grid.height) # todo need to add this for minigrid2prism
  216. #print("Dimensions: {} x {}".format(self.grid.height, self.grid.width))
  217. #self._gen_grid(self.grid.width, self.grid.height)
  218. # Map of object types to short string
  219. # Map agent's direction to short string
  220. AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}
  221. str = ""
  222. background_str = ""
  223. adversaries = {adv.adversary_pos: adv for adv in self.adversaries.values()} if self.adversaries else {}
  224. bfs_rewards = []
  225. for j in range(self.grid.height):
  226. for i in range(self.grid.width):
  227. b = self.grid.get_background(i, j)
  228. c = self.grid.get(i, j)
  229. if (i,j) in adversaries.keys():
  230. a = adversaries[(i,j)]
  231. str += OBJECT_TO_STR["adversary"] + a.color[0].upper()
  232. if init:
  233. background_str += " "
  234. continue
  235. if init:
  236. if c and c.type == "wall":
  237. background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  238. elif c and c.type in ["slipperynorth", "slipperyeast", "slipperysouth", "slipperywest", "slipperynorthwest", "slipperynortheast", "slipperysoutheast", "slipperysouthwest"]:
  239. background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  240. elif b is None:
  241. background_str += " "
  242. else:
  243. if b.type != "floor":
  244. type_str = OBJECT_TO_STR[b.type]
  245. else:
  246. type_str = " "
  247. background_str += type_str + b.color.replace("light","")[0].upper()
  248. if hasattr(self, "bfs_reward") and self.bfs_reward:
  249. bfs_rewards.append(f"{i};{j};{self.bfs_reward[i + self.grid.width * j]}")
  250. if self.agent_pos is not None and i == self.agent_pos[0] and j == self.agent_pos[1]:
  251. if init:
  252. str += "XR"
  253. else:
  254. str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
  255. continue
  256. if c is None:
  257. str += " "
  258. continue
  259. if c.type == "door":
  260. if c.is_open:
  261. str += "__"
  262. elif c.is_locked:
  263. str += "L" + c.color[0].upper()
  264. else:
  265. str += "D" + c.color[0].upper()
  266. continue
  267. str += OBJECT_TO_STR[c.type] + c.color[0].upper()
  268. if j < self.grid.height - 1:
  269. str += "\n"
  270. if init:
  271. background_str += "\n"
  272. seperator = "-" * self.grid.width * 2
  273. if init and hasattr(self, "bfs_reward") and self.bfs_reward:
  274. return str + "\n" + seperator + "\n" + background_str + "\n" + seperator + "\n" + ";".join(bfs_rewards) + "\n" + seperator + "\n"
  275. else:
  276. return str + "\n" + seperator + "\n" + background_str + "\n" + seperator + "\n" + seperator + "\n"
  277. def export_grid(self, filename="grid.txt"):
  278. with open(filename, "w") as gridFile:
  279. gridFile.write(self.printGrid(init=True))
  280. @abstractmethod
  281. def _gen_grid(self, width, height):
  282. pass
  283. def _reward(self) -> float:
  284. """
  285. Compute the reward to be given upon success
  286. """
  287. return 1 - 0.9 * (self.step_count / self.max_steps)
  288. def _rand_int(self, low: int, high: int) -> int:
  289. """
  290. Generate random integer in [low,high[
  291. """
  292. return self.np_random.integers(low, high)
  293. def _rand_float(self, low: float, high: float) -> float:
  294. """
  295. Generate random float in [low,high[
  296. """
  297. return self.np_random.uniform(low, high)
  298. def _rand_bool(self) -> bool:
  299. """
  300. Generate random boolean value
  301. """
  302. return self.np_random.integers(0, 2) == 0
  303. def _rand_elem(self, iterable: Iterable[T]) -> T:
  304. """
  305. Pick a random element in a list
  306. """
  307. lst = list(iterable)
  308. idx = self._rand_int(0, len(lst))
  309. return lst[idx]
  310. def _rand_subset(self, iterable: Iterable[T], num_elems: int) -> list[T]:
  311. """
  312. Sample a random subset of distinct elements of a list
  313. """
  314. lst = list(iterable)
  315. assert num_elems <= len(lst)
  316. out: list[T] = []
  317. while len(out) < num_elems:
  318. elem = self._rand_elem(lst)
  319. lst.remove(elem)
  320. out.append(elem)
  321. return out
  322. def _rand_color(self) -> str:
  323. """
  324. Generate a random color name (string)
  325. """
  326. return self._rand_elem(COLOR_NAMES)
  327. def _rand_pos(
  328. self, x_low: int, x_high: int, y_low: int, y_high: int
  329. ) -> tuple[int, int]:
  330. """
  331. Generate a random (x,y) position tuple
  332. """
  333. return (
  334. self.np_random.integers(x_low, x_high),
  335. self.np_random.integers(y_low, y_high),
  336. )
  337. def place_obj(
  338. self,
  339. obj: WorldObj | None,
  340. top: Point = None,
  341. size: tuple[int, int] = None,
  342. reject_fn=None,
  343. max_tries=math.inf,
  344. ):
  345. """
  346. Place an object at an empty position in the grid
  347. :param top: top-left position of the rectangle where to place
  348. :param size: size of the rectangle where to place
  349. :param reject_fn: function to filter out potential positions
  350. """
  351. if top is None:
  352. top = (0, 0)
  353. else:
  354. top = (max(top[0], 0), max(top[1], 0))
  355. if size is None:
  356. size = (self.grid.width, self.grid.height)
  357. num_tries = 0
  358. while True:
  359. # This is to handle with rare cases where rejection sampling
  360. # gets stuck in an infinite loop
  361. if num_tries > max_tries:
  362. raise RecursionError("rejection sampling failed in place_obj")
  363. num_tries += 1
  364. pos = (
  365. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  366. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  367. )
  368. # Don't place the object on top of another object
  369. if self.grid.get(*pos) is not None:
  370. continue
  371. # Don't place the object where the agent is
  372. if np.array_equal(pos, self.agent_pos):
  373. continue
  374. # Check if there is a filtering criterion
  375. if reject_fn and reject_fn(self, pos):
  376. continue
  377. break
  378. self.grid.set(pos[0], pos[1], obj)
  379. if obj is not None:
  380. obj.init_pos = pos
  381. obj.cur_pos = pos
  382. return pos
  383. def put_obj(self, obj: WorldObj, i: int, j: int):
  384. """
  385. Put an object at a specific position in the grid
  386. """
  387. self.grid.set(i, j, obj)
  388. obj.init_pos = (i, j)
  389. obj.cur_pos = (i, j)
  390. if obj.can_pickup():
  391. self.objects.append(obj)
  392. self.objects = sorted(self.objects, key=lambda object: object.color)
  393. if obj.type == "door":
  394. self.doors.append(obj)
  395. self.doors = sorted(self.doors, key=lambda object: object.color)
  396. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  397. """
  398. Set the agent's starting point at an empty position in the grid
  399. """
  400. self.agent_pos = (-1, -1)
  401. pos = self.place_obj(None, top, size, max_tries=max_tries)
  402. self.agent_pos = pos
  403. if rand_dir:
  404. self.agent_dir = self._rand_int(0, 4)
  405. return pos
  406. def disable_random_start(self):
  407. pass
  408. def add_slippery_tile(self, i: int, j: int, type: str):
  409. """
  410. Adds a slippery tile to the grid
  411. """
  412. if type=="slipperynorth":
  413. slippery_tile = SlipperyNorth()
  414. elif type=="slipperysouth":
  415. slippery_tile = SlipperySouth()
  416. elif type=="slipperyeast":
  417. slippery_tile = SlipperyEast()
  418. elif type=="slipperywest":
  419. slippery_tile = SlipperyWest()
  420. else:
  421. slippery_tile = SlipperyNorth()
  422. self.grid.set(i, j, slippery_tile)
  423. return (i, j)
  424. @property
  425. def dir_vec(self):
  426. """
  427. Get the direction vector for the agent, pointing in the direction
  428. of forward movement.
  429. """
  430. assert (
  431. self.agent_dir >= 0 and self.agent_dir < 4
  432. ), f"Invalid agent_dir: {self.agent_dir} is not within range(0, 4)"
  433. return DIR_TO_VEC[self.agent_dir]
  434. @property
  435. def right_vec(self):
  436. """
  437. Get the vector pointing to the right of the agent.
  438. """
  439. dx, dy = self.dir_vec
  440. return np.array((-dy, dx))
  441. @property
  442. def front_pos(self):
  443. """
  444. Get the position of the cell that is right in front of the agent
  445. """
  446. return self.agent_pos + self.dir_vec
  447. def get_view_coords(self, i, j):
  448. """
  449. Translate and rotate absolute grid coordinates (i, j) into the
  450. agent's partially observable view (sub-grid). Note that the resulting
  451. coordinates may be negative or outside of the agent's view size.
  452. """
  453. ax, ay = self.agent_pos
  454. dx, dy = self.dir_vec
  455. rx, ry = self.right_vec
  456. # Compute the absolute coordinates of the top-left view corner
  457. sz = self.agent_view_size
  458. hs = self.agent_view_size // 2
  459. tx = ax + (dx * (sz - 1)) - (rx * hs)
  460. ty = ay + (dy * (sz - 1)) - (ry * hs)
  461. lx = i - tx
  462. ly = j - ty
  463. # Project the coordinates of the object relative to the top-left
  464. # corner onto the agent's own coordinate system
  465. vx = rx * lx + ry * ly
  466. vy = -(dx * lx + dy * ly)
  467. return vx, vy
  468. def get_view_exts(self, agent_view_size=None):
  469. """
  470. Get the extents of the square set of tiles visible to the agent
  471. Note: the bottom extent indices are not included in the set
  472. if agent_view_size is None, use self.agent_view_size
  473. """
  474. agent_view_size = agent_view_size or self.agent_view_size
  475. # Facing right
  476. if self.agent_dir == 0:
  477. topX = self.agent_pos[0]
  478. topY = self.agent_pos[1] - agent_view_size // 2
  479. # Facing down
  480. elif self.agent_dir == 1:
  481. topX = self.agent_pos[0] - agent_view_size // 2
  482. topY = self.agent_pos[1]
  483. # Facing left
  484. elif self.agent_dir == 2:
  485. topX = self.agent_pos[0] - agent_view_size + 1
  486. topY = self.agent_pos[1] - agent_view_size // 2
  487. # Facing up
  488. elif self.agent_dir == 3:
  489. topX = self.agent_pos[0] - agent_view_size // 2
  490. topY = self.agent_pos[1] - agent_view_size + 1
  491. else:
  492. assert False, "invalid agent direction"
  493. botX = topX + agent_view_size
  494. botY = topY + agent_view_size
  495. return topX, topY, botX, botY
  496. def relative_coords(self, x, y):
  497. """
  498. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  499. """
  500. vx, vy = self.get_view_coords(x, y)
  501. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  502. return None
  503. return vx, vy
  504. def get_neighbours(self, i, j):
  505. neighbours = list()
  506. potential_neighbours = [(i-1,j), (i,j+1), (i+1,j), (i,j-1)]
  507. for n in potential_neighbours:
  508. cell = self.grid.get(*n)
  509. if cell is None or (cell.can_overlap()): #and not isinstance(cell, Lava)):
  510. neighbours.append(n)
  511. return neighbours
  512. def run_BFS_reward(grid):
  513. if not hasattr(grid, "goal_pos") or np.all(grid.goal_pos == (-1, -1)):
  514. return []
  515. starting_position = (grid.goal_pos[0], grid.goal_pos[1])
  516. max_distance = 0
  517. distances = [None] * grid.width * grid.height
  518. bfs_queue = deque([starting_position])
  519. traversed_cells = set()
  520. distances[starting_position[0] + grid.width * starting_position[1]] = 0
  521. while bfs_queue:
  522. current_cell = bfs_queue.pop()
  523. if current_cell in traversed_cells: continue
  524. traversed_cells.add(current_cell)
  525. current_distance = distances[current_cell[0] + grid.width * current_cell[1]]
  526. if current_distance > max_distance:
  527. max_distance = current_distance
  528. for neighbour in grid.get_neighbours(*current_cell):
  529. if neighbour in traversed_cells:
  530. continue
  531. bfs_queue.appendleft(neighbour)
  532. if distances[neighbour[0] + grid.width * neighbour[1]] is None:
  533. distances[neighbour[0] + grid.width * neighbour[1]] = current_distance + 1
  534. distances = [x if x else 0 for x in distances]
  535. # return [ (-x/1) for x in distances]
  536. return [ (1/4)* (-x/max_distance) if x != 0 else 0 for x in distances]
  537. def print_bfs_reward(self):
  538. rep = ""
  539. for j in range(self.grid.height):
  540. for i in range(self.grid.width):
  541. rep += F"{self.bfs_reward[j * self.grid.height + i]:5.2f} "
  542. rep += '\n'
  543. print(rep)
  544. def in_view(self, x, y):
  545. """
  546. check if a grid position is visible to the agent
  547. """
  548. return self.relative_coords(x, y) is not None
  549. def agent_sees(self, x, y):
  550. """
  551. Check if a non-empty grid position is visible to the agent
  552. """
  553. coordinates = self.relative_coords(x, y)
  554. if coordinates is None:
  555. return False
  556. vx, vy = coordinates
  557. obs = self.gen_obs()
  558. obs_grid, _ = Grid.decode(obs["image"])
  559. obs_cell = obs_grid.get(vx, vy)
  560. world_cell = self.grid.get(x, y)
  561. assert world_cell is not None
  562. return obs_cell is not None and obs_cell.type == world_cell.type
  563. def step(
  564. self, action: ActType
  565. ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]:
  566. self.step_count += 1
  567. reward = 0
  568. terminated = False
  569. truncated = False
  570. info = dict()
  571. need_position_update = False
  572. # Get the position in front of the agent
  573. fwd_pos = self.front_pos
  574. # Get the contents of the cell in front of the agent
  575. fwd_cell = self.grid.get(*fwd_pos)
  576. current_cell = self.grid.get(*self.agent_pos)
  577. opened_door = False
  578. picked_up = False
  579. if action == self.actions.forward and is_slippery(current_cell):
  580. probabilities = current_cell.get_probabilities(self.agent_dir)
  581. possible_fwd_pos, prob = self.get_neighbours_prob(self.agent_pos, probabilities)
  582. fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
  583. fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
  584. fwd_cell = self.grid.get(*fwd_pos)
  585. need_position_update = True
  586. # Rotate left
  587. elif action == self.actions.left:
  588. if is_slippery(current_cell):
  589. possible_fwd_pos, prob = self.get_neighbours_prob(self.agent_pos, current_cell.probabilities_turn)
  590. fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
  591. fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
  592. fwd_cell = self.grid.get(*fwd_pos)
  593. if fwd_pos == (self.agent_pos[0], self.agent_pos[1]):
  594. self.agent_dir -= 1
  595. if self.agent_dir < 0:
  596. self.agent_dir += 4
  597. else:
  598. need_position_update = True
  599. else:
  600. self.agent_dir -= 1
  601. if self.agent_dir < 0:
  602. self.agent_dir += 4
  603. # Rotate right
  604. elif action == self.actions.right:
  605. if is_slippery(current_cell):
  606. possible_fwd_pos, prob = self.get_neighbours_prob(self.agent_pos, current_cell.probabilities_turn)
  607. fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
  608. fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
  609. fwd_cell = self.grid.get(*fwd_pos)
  610. if fwd_pos == (self.agent_pos[0], self.agent_pos[1]):
  611. self.agent_dir = (self.agent_dir + 1) % 4
  612. else:
  613. need_position_update = True
  614. else:
  615. self.agent_dir = (self.agent_dir + 1) % 4
  616. # Move forward
  617. elif action == self.actions.forward:
  618. if fwd_cell is None or fwd_cell.can_overlap():
  619. self.agent_pos = tuple(fwd_pos)
  620. fwd_cell = self.grid.get(*fwd_pos)
  621. need_position_update = True
  622. # Pick up an object
  623. elif action == self.actions.pickup:
  624. if fwd_cell and fwd_cell.can_pickup():
  625. if self.carrying is None:
  626. self.carrying = fwd_cell
  627. self.carrying.cur_pos = np.array([-1, -1])
  628. self.grid.set(fwd_pos[0], fwd_pos[1], None)
  629. picked_up = True
  630. # Drop an object
  631. elif action == self.actions.drop:
  632. if not fwd_cell and self.carrying:
  633. self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
  634. self.carrying.cur_pos = fwd_pos
  635. self.carrying = None
  636. # Toggle/activate an object
  637. elif action == self.actions.toggle:
  638. if fwd_cell:
  639. fwd_cell.toggle(self, fwd_pos)
  640. if fwd_cell.type == "door" and fwd_cell.is_open:
  641. opened_door = True
  642. # Done action (not used by default)
  643. elif action == self.actions.done:
  644. pass
  645. else:
  646. raise ValueError(f"Unknown action: {action}")
  647. if need_position_update and (fwd_cell is None or fwd_cell.can_overlap()):
  648. self.agent_pos = tuple(fwd_pos)
  649. current_cell = self.grid.get(*self.agent_pos)
  650. collision = False
  651. if self.adversaries:
  652. for adversary in self.adversaries.values():
  653. if np.array_equal(self.agent_pos, adversary.adversary_pos):
  654. collision = True
  655. reached_goal = False
  656. ran_into_lava = False
  657. if current_cell is not None and current_cell.type == "goal":
  658. terminated = True
  659. reached_goal = True
  660. try: reward = self.goal_reward
  661. except: reward = 1
  662. elif current_cell is not None and current_cell.type == "lava":
  663. terminated = True
  664. ran_into_lava = True
  665. try: reward = self.failure_penalty
  666. except: reward = -1
  667. elif collision:
  668. terminated = True
  669. try: reward = self.collision_penalty
  670. except: reward = -1
  671. self.agent_pos = tuple(fwd_pos)
  672. else:
  673. try: reward += self.bfs_reward[self.agent_pos[0] + self.grid.width * self.agent_pos[1]]
  674. except: pass
  675. if self.step_count >= self.max_steps:
  676. truncated = True
  677. if self.render_mode == "human":
  678. self.render()
  679. info["reached_goal"] = reached_goal
  680. info["ran_into_lava"] = ran_into_lava
  681. info["opened_door"] = opened_door
  682. info["picked_up"] = picked_up
  683. #if terminated:
  684. # print(f"Terminated at: {self.agent_pos} {self.grid.get(*self.agent_pos)} {info}")
  685. if len(self.adversaries) > 0: info["collision"] = collision
  686. obs = self.gen_obs()
  687. return obs, reward, terminated, truncated, info
  688. def get_neighbours_prob(self, agent_pos, probabilities):
  689. neighbours = [tuple((x,y)) for x in range(agent_pos[0]-1, agent_pos[0]+2) for y in range(agent_pos[1]-1,agent_pos[1]+2)]
  690. probabilities_dict = dict(zip(neighbours, probabilities))
  691. for pos in probabilities_dict:
  692. cell = self.grid.get(*pos)
  693. if cell is not None and not cell.can_overlap():
  694. probabilities_dict[pos] = 0.0
  695. try:
  696. return list(probabilities_dict.keys()), [float(p) / sum(probabilities_dict.values()) for p in probabilities_dict.values()]
  697. except ZeroDivisionError as e:
  698. return list(probabilities_dict.keys()), stay_at_pos_distribution
  699. def gen_obs_grid(self, agent_view_size=None):
  700. """
  701. Generate the sub-grid observed by the agent.
  702. This method also outputs a visibility mask telling us which grid
  703. cells the agent can actually see.
  704. if agent_view_size is None, self.agent_view_size is used
  705. """
  706. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  707. agent_view_size = agent_view_size or self.agent_view_size
  708. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  709. for i in range(self.agent_dir + 1):
  710. grid = grid.rotate_left()
  711. # Process occluders and visibility
  712. # Note that this incurs some performance cost
  713. if not self.see_through_walls:
  714. vis_mask = grid.process_vis(
  715. agent_pos=(agent_view_size // 2, agent_view_size - 1)
  716. )
  717. else:
  718. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  719. # Make it so the agent sees what it's carrying
  720. # We do this by placing the carried object at the agent's position
  721. # in the agent's partially observable view
  722. agent_pos = grid.width // 2, grid.height - 1
  723. if self.carrying:
  724. grid.set(*agent_pos, self.carrying)
  725. else:
  726. grid.set(*agent_pos, None)
  727. return grid, vis_mask
  728. def gen_obs(self):
  729. """
  730. Generate the agent's view (partially observable, low-resolution encoding)
  731. """
  732. grid, vis_mask = self.gen_obs_grid()
  733. # Encode the partially observable view into a numpy array
  734. image = grid.encode(vis_mask)
  735. # Observations are dictionaries containing:
  736. # - an image (partially observable view of the environment)
  737. # - the agent's direction/orientation (acting as a compass)
  738. # - a textual mission string (instructions for the agent)
  739. obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
  740. return obs
  741. def get_pov_render(self, tile_size):
  742. """
  743. Render an agent's POV observation for visualization
  744. """
  745. grid, vis_mask = self.gen_obs_grid()
  746. # Render the whole grid
  747. img = grid.render(
  748. tile_size,
  749. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  750. agent_dir=3,
  751. adversaries=self.adversaries.values(),
  752. highlight_mask=vis_mask,
  753. )
  754. return img
  755. def get_full_render(self, highlight, tile_size):
  756. """
  757. Render a non-paratial observation for visualization
  758. """
  759. # Compute which cells are visible to the agent
  760. _, vis_mask = self.gen_obs_grid()
  761. # Compute the world coordinates of the bottom-left corner
  762. # of the agent's view area
  763. f_vec = self.dir_vec
  764. r_vec = self.right_vec
  765. top_left = (
  766. self.agent_pos
  767. + f_vec * (self.agent_view_size - 1)
  768. - r_vec * (self.agent_view_size // 2)
  769. )
  770. # Mask of which cells to highlight
  771. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  772. # For each cell in the visibility mask
  773. for vis_j in range(0, self.agent_view_size):
  774. for vis_i in range(0, self.agent_view_size):
  775. # If this cell is not visible, don't highlight it
  776. if not vis_mask[vis_i, vis_j]:
  777. continue
  778. # Compute the world coordinates of this cell
  779. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  780. if abs_i < 0 or abs_i >= self.width:
  781. continue
  782. if abs_j < 0 or abs_j >= self.height:
  783. continue
  784. # Mark this cell to be highlighted
  785. highlight_mask[abs_i, abs_j] = True
  786. # Render the whole grid
  787. img = self.grid.render(
  788. tile_size,
  789. self.agent_pos,
  790. self.agent_dir,
  791. adversaries=self.adversaries.values() if self.adversaries else [],
  792. highlight_mask=highlight_mask if highlight else None,
  793. )
  794. return img
  795. def get_frame(
  796. self,
  797. highlight: bool = True,
  798. tile_size: int = TILE_PIXELS,
  799. agent_pov: bool = False,
  800. ):
  801. """Returns an RGB image corresponding to the whole environment or the agent's point of view.
  802. Args:
  803. highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
  804. tile_size (int): How many pixels will form a tile from the NxM grid.
  805. agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
  806. Returns:
  807. frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
  808. """
  809. if agent_pov:
  810. return self.get_pov_render(tile_size)
  811. else:
  812. return self.get_full_render(highlight, tile_size)
  813. def render(self):
  814. img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
  815. screen_width = 2 * self.tile_size * self.grid.width
  816. screen_height = 2 * self.tile_size * self.grid.height
  817. if self.render_mode == "human":
  818. img = np.transpose(img, axes=(1, 0, 2))
  819. if self.render_size is None:
  820. self.render_size = img.shape[:2]
  821. if self.window is None:
  822. pygame.init()
  823. pygame.display.init()
  824. self.window = pygame.display.set_mode(
  825. (screen_width, screen_height)
  826. )
  827. pygame.display.set_caption("minigrid")
  828. if self.clock is None:
  829. self.clock = pygame.time.Clock()
  830. surf = pygame.surfarray.make_surface(img)
  831. # Create background with mission description
  832. offset = surf.get_size()[0] * 0.1
  833. offset = 0
  834. # offset = 32 if self.agent_pov else 64
  835. bg = pygame.Surface(
  836. (int(surf.get_size()[0] + offset), int(surf.get_size()[1] + offset))
  837. )
  838. bg.convert()
  839. bg.fill((255, 255, 255))
  840. bg.blit(surf, (offset / 2, 0))
  841. bg = pygame.transform.smoothscale(bg, (screen_width, screen_height))
  842. #font_size = 22
  843. #text = self.mission
  844. #font = pygame.freetype.SysFont(pygame.font.get_default_font(), font_size)
  845. #text_rect = font.get_rect(text, size=font_size)
  846. #text_rect.center = bg.get_rect().center
  847. #text_rect.y = bg.get_height() - font_size * 1.5
  848. #font.render_to(bg, text_rect, text, size=font_size)
  849. self.window.blit(bg, (0, 0))
  850. pygame.event.pump()
  851. self.clock.tick(self.metadata["render_fps"])
  852. pygame.display.flip()
  853. elif self.render_mode == "rgb_array":
  854. return img
  855. def get_symbolic_state(self):
  856. adversaries = tuple()
  857. balls = tuple()
  858. keys = tuple()
  859. boxes = tuple()
  860. doors = tuple()
  861. for obj in self.objects:
  862. if obj.type == "box":
  863. boxes += (obj.to_state(),)
  864. if obj.type == "ball":
  865. balls += (obj.to_state(),)
  866. if obj.type == "key":
  867. keys += (obj.to_state(),)
  868. for door in self.doors:
  869. doors += (door.to_state(),)
  870. for color in COLOR_NAMES:
  871. try:
  872. adversaries += (self.adversaries[color].to_state(),)
  873. except Exception as e:
  874. pass
  875. carrying = "" if not self.carrying else f"{self.carrying.color.capitalize()}{self.carrying.type.capitalize()}"
  876. state = State(colAgent=self.agent_pos[0], rowAgent=self.agent_pos[1], viewAgent=self.agent_dir, carrying=carrying, adversaries=adversaries, keys=keys, doors=doors)
  877. return state
  878. def close(self):
  879. if self.window:
  880. pygame.quit()