You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

2102 lines
68 KiB

1 year ago
  1. import hashlib
  2. import math
  3. from abc import abstractmethod
  4. from enum import IntEnum
  5. from typing import Any, Callable, Optional, Union, List
  6. import gym
  7. import numpy
  8. import numpy as np
  9. from gym import spaces
  10. from gym.utils import seeding
  11. from collections import deque
  12. from copy import deepcopy
  13. import colorsys
  14. TASKREWARD = 20
  15. PICKUPREWARD = TASKREWARD
  16. DELIVERREWARD = TASKREWARD
  17. # Size in pixels of a tile in the full-scale human view
  18. from gym_minigrid.rendering import (
  19. downsample,
  20. fill_coords,
  21. highlight_img,
  22. point_in_circle,
  23. point_in_line,
  24. point_in_rect,
  25. point_in_triangle,
  26. rotate_fn,
  27. )
  28. from gym_minigrid.window import Window
  29. #from gym_minigrid.Task import DoRandom, TaskManager, DoNothing, GoTo, PlaceObject, PickUpObject, Task
  30. TILE_PIXELS = 32
  31. # Map of color names to RGB values
  32. COLORS = {
  33. "red": np.array([255, 0, 0]),
  34. "lightred": np.array([255, 165, 165]),
  35. "green": np.array([0, 255, 0]),
  36. "lightgreen": np.array([165, 255, 165]),
  37. "blue": np.array([0, 0, 255]),
  38. "lightblue": np.array([165, 165, 255]),
  39. "purple": np.array([112, 39, 195]),
  40. "lightpurple": np.array([202, 170, 238]),
  41. "yellow": np.array([255, 255, 0]),
  42. "grey": np.array([100, 100, 100]),
  43. }
  44. COLOR_NAMES = sorted(list(COLORS.keys()))
  45. # Used to map colors to integers
  46. COLOR_TO_IDX = {"red": 0, "green": 1, "blue": 2, "purple": 3, "yellow": 4, "grey": 5, "lightred": 6, "lightblue":7, "lightgreen":8, "lightpurple": 9}
  47. IDX_TO_COLOR = dict(zip(COLOR_TO_IDX.values(), COLOR_TO_IDX.keys()))
  48. def color_bins(nr_bins):
  49. color_1 = np.array([222,237,182])
  50. color_2 = np.array([45,78,157])
  51. inter_values = [1./nr_bins * i for i in range(nr_bins+1)]
  52. color_values = []
  53. for value in inter_values:
  54. color_values.append((color_2 - color_1) * value + color_1)
  55. return color_values
  56. def color_states(value, ranking, nr_bins):
  57. set_size = len(ranking)
  58. bin_size = math.floor(set_size/nr_bins)
  59. bin_bound_values = []
  60. for i in range(nr_bins):
  61. bin_bound_values.append(ranking[bin_size*i])
  62. bin_bound_values.append(1.)
  63. if len(set(bin_bound_values))!= len(bin_bound_values):
  64. print("at least two bounds have the same value, you should probably reduce the bin size")
  65. color_values = color_bins(nr_bins)
  66. for i in range(nr_bins):
  67. if(value >= bin_bound_values[i] and value <=bin_bound_values[i+1]):
  68. if bin_bound_values[i]==bin_bound_values[i+1]:
  69. inter_value = 0
  70. else:
  71. inter_value = (value - bin_bound_values[i])/(bin_bound_values[i+1] - bin_bound_values[i])
  72. return (color_values[i+1] - color_values[i]) * inter_value + color_values[i]
  73. def isSlippery(cell):
  74. if isinstance(cell, SlipperyNorth):
  75. return True
  76. elif isinstance(cell, SlipperySouth):
  77. return True
  78. elif isinstance(cell, SlipperyEast):
  79. return True
  80. elif isinstance(cell, SlipperyWest):
  81. return True
  82. else:
  83. return False
  84. # Map of object type to integers
  85. OBJECT_TO_IDX = {
  86. "unseen": 0,
  87. "empty": 1,
  88. "wall": 2,
  89. "floor": 3,
  90. "door": 4,
  91. "key": 5,
  92. "ball": 6,
  93. "box": 7,
  94. "goal": 8,
  95. "lava": 9,
  96. "agent": 10,
  97. "adversary": 11,
  98. "slipperynorth": 12,
  99. "slipperysouth": 13,
  100. "slipperyeast": 14,
  101. "slipperywest": 15,
  102. "heattile" : 16,
  103. "heattilereduced" : 17
  104. }
  105. IDX_TO_OBJECT = dict(zip(OBJECT_TO_IDX.values(), OBJECT_TO_IDX.keys()))
  106. # Map of state names to integers
  107. STATE_TO_IDX = {
  108. "open": 0,
  109. "closed": 1,
  110. "locked": 2,
  111. }
  112. # Map of agent direction indices to vectors
  113. DIR_TO_VEC = [
  114. # Pointing right (positive X)
  115. np.array((1, 0)),
  116. # Down (positive Y)
  117. np.array((0, 1)),
  118. # Pointing left (negative X)
  119. np.array((-1, 0)),
  120. # Up (negative Y)
  121. np.array((0, -1)),
  122. ]
  123. def check_if_no_duplicate(duplicate_list: list) -> bool:
  124. """Check if given list contains any duplicates"""
  125. return len(set(duplicate_list)) == len(duplicate_list)
  126. def run_BFS_reward(grid, starting_position):
  127. max_distance = 0
  128. distances = [None] * grid.width * grid.height
  129. bfs_queue = deque([starting_position])
  130. traversed_cells = set()
  131. distances[starting_position[0] + grid.width * starting_position[1]] = 0
  132. while bfs_queue:
  133. current_cell = bfs_queue.pop()
  134. traversed_cells.add(current_cell)
  135. current_distance = distances[current_cell[0] + grid.width * current_cell[1]]
  136. if current_distance > max_distance:
  137. max_distance = current_distance
  138. for neighbour in grid.get_neighbours(*current_cell):
  139. if neighbour in traversed_cells:
  140. continue
  141. bfs_queue.appendleft(neighbour)
  142. if distances[neighbour[0] + grid.width * neighbour[1]] is None:
  143. distances[neighbour[0] + grid.width * neighbour[1]] = current_distance + 1
  144. distances = [x if x else 0 for x in distances]
  145. #for i, x in enumerate(distances):
  146. # if i % grid.width == 0:
  147. # print("")
  148. # if i is None:
  149. # print(" ")
  150. # continue
  151. # else:
  152. # print("{:0>4},".format(x), end="")
  153. #print("")
  154. return [-x/max_distance for x in distances]
  155. class MissionSpace(spaces.Space[str]):
  156. r"""A space representing a mission for the Gym-Minigrid environments.
  157. The space allows generating random mission strings constructed with an input placeholder list.
  158. Example Usage::
  159. >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
  160. ordered_placeholders=[["green", "blue"]])
  161. >>> observation_space.sample()
  162. "Get the green ball."
  163. >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.".,
  164. ordered_placeholders=None)
  165. >>> observation_space.sample()
  166. "Get the ball."
  167. """
  168. def __init__(
  169. self,
  170. mission_func: Callable[..., str],
  171. ordered_placeholders: Optional["list[list[str]]"] = None,
  172. seed: Optional[Union[int, seeding.RandomNumberGenerator]] = None,
  173. ):
  174. r"""Constructor of :class:`MissionSpace` space.
  175. Args:
  176. mission_func (lambda _placeholders(str): _mission(str)): Function that generates a mission string from random placeholders.
  177. ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders ordered in placing order in the mission function mission_func.
  178. seed: seed: The seed for sampling from the space.
  179. """
  180. # Check that the ordered placeholders and mission function are well defined.
  181. if ordered_placeholders is not None:
  182. assert (
  183. len(ordered_placeholders) == mission_func.__code__.co_argcount
  184. ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
  185. for placeholder_list in ordered_placeholders:
  186. assert check_if_no_duplicate(
  187. placeholder_list
  188. ), "Make sure that the placeholders don't have any duplicate values."
  189. else:
  190. assert (
  191. mission_func.__code__.co_argcount == 0
  192. ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
  193. self.ordered_placeholders = ordered_placeholders
  194. self.mission_func = mission_func
  195. super().__init__(dtype=str, seed=seed)
  196. # Check that mission_func returns a string
  197. sampled_mission = self.sample()
  198. assert isinstance(
  199. sampled_mission, str
  200. ), f"mission_func must return type str not {type(sampled_mission)}"
  201. def sample(self) -> str:
  202. """Sample a random mission string."""
  203. if self.ordered_placeholders is not None:
  204. placeholders = []
  205. for rand_var_list in self.ordered_placeholders:
  206. idx = self.np_random.integers(0, len(rand_var_list))
  207. placeholders.append(rand_var_list[idx])
  208. return self.mission_func(*placeholders)
  209. else:
  210. return self.mission_func()
  211. def contains(self, x: Any) -> bool:
  212. """Return boolean specifying if x is a valid member of this space."""
  213. # Store a list of all the placeholders from self.ordered_placeholders that appear in x
  214. if self.ordered_placeholders is not None:
  215. check_placeholder_list = []
  216. for placeholder_list in self.ordered_placeholders:
  217. for placeholder in placeholder_list:
  218. if placeholder in x:
  219. check_placeholder_list.append(placeholder)
  220. # Remove duplicates from the list
  221. check_placeholder_list = list(set(check_placeholder_list))
  222. start_id_placeholder = []
  223. end_id_placeholder = []
  224. # Get the starting and ending id of the identified placeholders with possible duplicates
  225. new_check_placeholder_list = []
  226. for placeholder in check_placeholder_list:
  227. new_start_id_placeholder = [
  228. i for i in range(len(x)) if x.startswith(placeholder, i)
  229. ]
  230. new_check_placeholder_list += [placeholder] * len(
  231. new_start_id_placeholder
  232. )
  233. end_id_placeholder += [
  234. start_id + len(placeholder) - 1
  235. for start_id in new_start_id_placeholder
  236. ]
  237. start_id_placeholder += new_start_id_placeholder
  238. # Order by starting id the placeholders
  239. ordered_placeholder_list = sorted(
  240. zip(
  241. start_id_placeholder, end_id_placeholder, new_check_placeholder_list
  242. )
  243. )
  244. # Check for repeated placeholders contained in each other
  245. remove_placeholder_id = []
  246. for i, placeholder_1 in enumerate(ordered_placeholder_list):
  247. starting_id = i + 1
  248. for j, placeholder_2 in enumerate(
  249. ordered_placeholder_list[starting_id:]
  250. ):
  251. # Check if place holder ids overlap and keep the longest
  252. if max(placeholder_1[0], placeholder_2[0]) < min(
  253. placeholder_1[1], placeholder_2[1]
  254. ):
  255. remove_placeholder = min(
  256. placeholder_1[2], placeholder_2[2], key=len
  257. )
  258. if remove_placeholder == placeholder_1[2]:
  259. remove_placeholder_id.append(i)
  260. else:
  261. remove_placeholder_id.append(i + j + 1)
  262. for id in remove_placeholder_id:
  263. del ordered_placeholder_list[id]
  264. final_placeholders = [
  265. placeholder[2] for placeholder in ordered_placeholder_list
  266. ]
  267. # Check that the identified final placeholders are in the same order as the original placeholders.
  268. for orered_placeholder, final_placeholder in zip(
  269. self.ordered_placeholders, final_placeholders
  270. ):
  271. if final_placeholder in orered_placeholder:
  272. continue
  273. else:
  274. return False
  275. try:
  276. mission_string_with_placeholders = self.mission_func(
  277. *final_placeholders
  278. )
  279. except Exception as e:
  280. print(
  281. f"{x} is not contained in MissionSpace due to the following exception: {e}"
  282. )
  283. return False
  284. return bool(mission_string_with_placeholders == x)
  285. else:
  286. return bool(self.mission_func() == x)
  287. def __repr__(self) -> str:
  288. """Gives a string representation of this space."""
  289. return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"
  290. def __eq__(self, other) -> bool:
  291. """Check whether ``other`` is equivalent to this instance."""
  292. if isinstance(other, MissionSpace):
  293. # Check that place holder lists are the same
  294. if self.ordered_placeholders is not None:
  295. # Check length
  296. if (len(self.order_placeholder) == len(other.order_placeholder)) and (
  297. all(
  298. set(i) == set(j)
  299. for i, j in zip(self.order_placeholder, other.order_placeholder)
  300. )
  301. ):
  302. # Check mission string is the same with dummy space placeholders
  303. test_placeholders = [""] * len(self.order_placeholder)
  304. mission = self.mission_func(*test_placeholders)
  305. other_mission = other.mission_func(*test_placeholders)
  306. return mission == other_mission
  307. else:
  308. # Check that other is also None
  309. if other.ordered_placeholders is None:
  310. # Check mission string is the same
  311. mission = self.mission_func()
  312. other_mission = other.mission_func()
  313. return mission == other_mission
  314. # If none of the statements above return then False
  315. return False
  316. class WorldObj:
  317. """
  318. Base class for grid world objects
  319. """
  320. def __init__(self, type, color):
  321. assert type in OBJECT_TO_IDX, type
  322. assert color in COLOR_TO_IDX, color
  323. self.type = type
  324. self.color = color
  325. self.contains = None
  326. # Initial position of the object
  327. self.init_pos = None
  328. # Current position of the object
  329. self.cur_pos = None
  330. def can_overlap(self):
  331. """Can the agent overlap with this?"""
  332. return False
  333. def can_pickup(self):
  334. """Can the agent pick this up?"""
  335. return False
  336. def can_contain(self):
  337. """Can this contain another object?"""
  338. return False
  339. def see_behind(self):
  340. """Can the agent see behind this object?"""
  341. return True
  342. def toggle(self, env, pos):
  343. """Method to trigger/toggle an action this object performs"""
  344. return False
  345. def encode(self):
  346. """Encode the a description of this object as a 3-tuple of integers"""
  347. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  348. @staticmethod
  349. def decode(type_idx, color_idx, state):
  350. """Create an object from a 3-tuple state description"""
  351. obj_type = IDX_TO_OBJECT[type_idx]
  352. color = IDX_TO_COLOR[color_idx]
  353. if obj_type == "empty" or obj_type == "unseen":
  354. return None
  355. # State, 0: open, 1: closed, 2: locked
  356. is_open = state == 0
  357. is_locked = state == 2
  358. if obj_type == "wall":
  359. v = Wall(color)
  360. elif obj_type == "floor":
  361. v = Floor(color)
  362. elif obj_type == "ball":
  363. v = Ball(color)
  364. elif obj_type == "key":
  365. v = Key(color)
  366. elif obj_type == "box":
  367. v = Box(color)
  368. elif obj_type == "door":
  369. v = Door(color, is_open, is_locked)
  370. elif obj_type == "goal":
  371. v = Goal()
  372. elif obj_type == "lava":
  373. v = Lava()
  374. elif obj_type == "slipperynorth":
  375. v = SlipperyNorth(color)
  376. elif obj_type == "slipperysouth":
  377. v = SlipperySouth(color)
  378. elif obj_type == "slipperywest":
  379. v = SlipperyWest(color)
  380. elif obj_type == "slipperyeast":
  381. v = SlipperyEast(color)
  382. elif obj_type == "heattile":
  383. v = HeatMapTile(color)
  384. elif obj_type == "heattilereduced":
  385. v = HeatMapTileReduced(color)
  386. else:
  387. assert False, "unknown object type in decode '%s'" % obj_type
  388. return v
  389. def render(self, r):
  390. """Draw this object with the given renderer"""
  391. raise NotImplementedError
  392. class Goal(WorldObj):
  393. def __init__(self):
  394. super().__init__("goal", "green")
  395. def can_overlap(self):
  396. return True
  397. def render(self, img):
  398. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  399. class Floor(WorldObj):
  400. """
  401. Colored floor tile the agent can walk over
  402. """
  403. def __init__(self, color="blue"):
  404. super().__init__("floor", color)
  405. def can_overlap(self):
  406. return True
  407. def render(self, img):
  408. # Give the floor a pale color
  409. color = COLORS[self.color] / 2
  410. fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), color)
  411. def can_contain(self):
  412. return True
  413. class Lava(WorldObj):
  414. def __init__(self):
  415. super().__init__("lava", "red")
  416. def can_overlap(self):
  417. return True
  418. def render(self, img):
  419. c = (255, 128, 0)
  420. # Background color
  421. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  422. # Little waves
  423. for i in range(3):
  424. ylo = 0.3 + 0.2 * i
  425. yhi = 0.4 + 0.2 * i
  426. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  427. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  428. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  429. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  430. class Wall(WorldObj):
  431. def __init__(self, color="grey"):
  432. super().__init__("wall", color)
  433. def see_behind(self):
  434. return False
  435. def render(self, img):
  436. fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
  437. class Door(WorldObj):
  438. def __init__(self, color, is_open=False, is_locked=False):
  439. super().__init__("door", color)
  440. self.is_open = is_open
  441. self.is_locked = is_locked
  442. def can_overlap(self):
  443. """The agent can only walk over this cell when the door is open"""
  444. return self.is_open
  445. def see_behind(self):
  446. return self.is_open
  447. def toggle(self, env, pos):
  448. # If the player has the right key to open the door
  449. if self.is_locked:
  450. if isinstance(env.carrying, Key) and env.carrying.color == self.color:
  451. self.is_locked = False
  452. self.is_open = True
  453. return True
  454. return False
  455. self.is_open = not self.is_open
  456. return True
  457. def encode(self):
  458. """Encode the a description of this object as a 3-tuple of integers"""
  459. # State, 0: open, 1: closed, 2: locked
  460. if self.is_open:
  461. state = 0
  462. elif self.is_locked:
  463. state = 2
  464. # if door is closed and unlocked
  465. elif not self.is_open:
  466. state = 1
  467. else:
  468. raise ValueError(
  469. f"There is no possible state encoding for the state:\n -Door Open: {self.is_open}\n -Door Closed: {not self.is_open}\n -Door Locked: {self.is_locked}"
  470. )
  471. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], state)
  472. def render(self, img):
  473. c = COLORS[self.color]
  474. if self.is_open:
  475. fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
  476. fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
  477. return
  478. # Door frame and door
  479. if self.is_locked:
  480. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  481. fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))
  482. # Draw key slot
  483. fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
  484. else:
  485. fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
  486. fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
  487. fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
  488. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))
  489. # Draw door handle
  490. fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)
  491. class Key(WorldObj):
  492. def __init__(self, color="blue"):
  493. super().__init__("key", color)
  494. def can_pickup(self):
  495. return True
  496. def render(self, img):
  497. c = COLORS[self.color]
  498. # Vertical quad
  499. fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)
  500. # Teeth
  501. fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
  502. fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)
  503. # Ring
  504. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
  505. fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))
  506. class Ball(WorldObj):
  507. def __init__(self, color="blue"):
  508. super().__init__("ball", color)
  509. def can_pickup(self):
  510. return True
  511. def render(self, img):
  512. fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
  513. class Box(WorldObj):
  514. def __init__(self, color, contains=None):
  515. super().__init__("box", color)
  516. self.contains = contains
  517. self.picked_up = False
  518. self.placed_at_destination = False
  519. def can_pickup(self):
  520. return not self.placed_at_destination
  521. def can_overlap(self):
  522. return True
  523. def render(self, img):
  524. c = COLORS[self.color]
  525. # Outline
  526. fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
  527. fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))
  528. # Horizontal slit
  529. fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
  530. def toggle(self, env, pos):
  531. # Replace the box by its contents
  532. return False # We dont want to destroy boxes, This is local change
  533. #env.grid.set(pos[0], pos[1], self.contains)
  534. #return True
  535. class SlipperyNorth(WorldObj):
  536. def __init__(self, color: str = "blue"):
  537. super().__init__("slipperynorth", color)
  538. self.probabilities_forward = [0.0, 1./9, 2./9, 0.0, -50, -50, 0.0, 1./9, 2./9]
  539. self.probabilities_turn = [0.0, 0.0, 1./9, 0.0, -50, 1./9, 0.0, 0.0, 1./9]
  540. self.offset = (0,1)
  541. def can_overlap(self):
  542. return True
  543. def render(self, img):
  544. c = (100, 100, 200)
  545. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  546. for i in range(6):
  547. ylo = 0.1 + 0.15 * i
  548. yhi = 0.20 + 0.15 * i
  549. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  550. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  551. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  552. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  553. class SlipperySouth(WorldObj):
  554. def __init__(self, color: str = "blue"):
  555. super().__init__("slipperysouth", color)
  556. self.probabilities_forward = [2./9, 1./9, 0.0, -50, -50, 0.0, 2./9, 1./9, 0.0]
  557. self.probabilities_turn = [1./9, 0.0, 0.0, 1./9, -50, 0.0, 1./9, 0.0, 0.0]
  558. self.offset = (0,-1)
  559. def can_overlap(self):
  560. return True
  561. def render(self, img):
  562. c = (100, 100, 200)
  563. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  564. for i in range(6):
  565. ylo = 0.1 + 0.15 * i
  566. yhi = 0.20 + 0.15 * i
  567. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  568. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  569. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  570. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  571. class SlipperyEast(WorldObj):
  572. def __init__(self, color: str = "blue"):
  573. super().__init__("slipperyeast", color)
  574. self.probabilities_forward = [2./9, -50, 2./9, 1./9., -50, 1./9, 0.0, 0.0, 0.0]
  575. self.probabilities_turn = [1./9, 1./9, 1./9, 0.0, -50, 0.0, 0.0, 0.0, 0.0]
  576. self.offset = (-1,0)
  577. def can_overlap(self):
  578. return True
  579. def render(self, img):
  580. c = (100, 100, 200)
  581. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  582. for i in range(6):
  583. ylo = 0.1 + 0.15 * i
  584. yhi = 0.20 + 0.15 * i
  585. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  586. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  587. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  588. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  589. class SlipperyWest(WorldObj):
  590. def __init__(self, color: str = "blue"):
  591. super().__init__("slipperywest", color)
  592. self.probabilities_forward = [0.0, 0.0, 0.0, 1./9., -50, 1./9, 2./9, -50, 2./9]
  593. self.probabilities_turn = [0.0, 0.0, 0.0, 0.0, -50, 0.0, 1./9, 1./9, 1./9]
  594. self.offset = (1,0)
  595. def can_overlap(self):
  596. return True
  597. def render(self, img):
  598. c = (100, 100, 200)
  599. fill_coords(img, point_in_rect(0, 1, 0, 1), c)
  600. for i in range(6):
  601. ylo = 0.1 + 0.15 * i
  602. yhi = 0.20 + 0.15 * i
  603. fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
  604. fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
  605. fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
  606. fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))
  607. #class Adversary(WorldObj):
  608. # def __init__(self, adversary_dir=1, color="blue", tasks=[DoRandom()]):
  609. # super().__init__("adversary", color)
  610. # self.adversary_dir = adversary_dir
  611. # self.color = color
  612. # self.task_manager = TaskManager(tasks)
  613. # self.carrying = None
  614. #
  615. # def render(self, img):
  616. # tri_fn = point_in_triangle(
  617. # (0.12, 0.19),
  618. # (0.87, 0.50),
  619. # (0.12, 0.81),
  620. # )
  621. #
  622. # # Rotate the agent based on its direction
  623. # tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * self.adversary_dir)
  624. # fill_coords(img, tri_fn, COLORS[self.color])
  625. #
  626. # def dir_vec(self):
  627. # assert self.adversary_dir >= 0 and self.adversary_dir < 4
  628. # return DIR_TO_VEC[self.adversary_dir]
  629. #
  630. # def can_overlap(self):
  631. # return False
  632. #
  633. # def encode(self):
  634. # return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], self.adversary_dir)
  635. tri_22 = point_in_triangle(
  636. (1.0, 0.0),
  637. (0.0, 0.0),
  638. (0.5, 0.5),
  639. )
  640. tri_33 = point_in_triangle(
  641. (0.0, 0.0),
  642. (0.0, 1.0),
  643. (0.5, 0.5),
  644. )
  645. tri_00 = point_in_triangle(
  646. (0.0, 1.0),
  647. (1.0, 1.0),
  648. (0.5, 0.5),
  649. )
  650. tri_11 = point_in_triangle(
  651. (1.0, 1.0),
  652. (1.0, 0.0),
  653. (0.5, 0.5),
  654. )
  655. def tri_mask(width, heigth):
  656. list22 = []
  657. list33 = []
  658. list00 = []
  659. list11 = []
  660. for x in range(width):
  661. for y in range(heigth):
  662. yf = (y + 0.5) / heigth
  663. xf = (x + 0.5) / width
  664. if tri_22(xf,yf):
  665. list22.append((x,y))
  666. elif tri_33(xf,yf):
  667. list33.append((x,y))
  668. elif tri_00(xf,yf):
  669. list00.append((x,y))
  670. elif tri_11(xf,yf):
  671. list11.append((x,y))
  672. return np.array(list22), np.array(list33), np.array(list00), np.array(list11)
  673. list22, list33, list00, list11 = tri_mask(96, 96)
  674. class HeatMapTile(WorldObj):
  675. def __init__(self, tile=dict(), ranking = [], nr_bins=5, color="blue"):
  676. super().__init__("heattile", color)
  677. self.tile_values = tile
  678. self.ranking = ranking
  679. self.nr_bins = nr_bins
  680. def can_overlap(self):
  681. return True
  682. def can_contain(self):
  683. return True
  684. def encode(self):
  685. """Encode the a description of this object as a 3-tuple of integers"""
  686. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX[self.color], 0)
  687. def render(self, img):
  688. img[tuple(list22.T)] = color_states(self.tile_values[2], self.ranking, self.nr_bins)
  689. img[tuple(list33.T)] = color_states(self.tile_values[3], self.ranking, self.nr_bins)
  690. img[tuple(list00.T)] = color_states(self.tile_values[0], self.ranking, self.nr_bins)
  691. img[tuple(list11.T)] = color_states(self.tile_values[1], self.ranking, self.nr_bins)
  692. class HeatMapTileReduced(WorldObj):
  693. def __init__(self, ranking_value=0, ranking=[], nr_bins=5, color="blue"):
  694. super().__init__("heattilereduced", color)
  695. self.ranking = ranking
  696. self.nr_bins = nr_bins
  697. self.color = color_states(ranking_value, self.ranking, self.nr_bins)
  698. def can_overlap(self):
  699. return True
  700. def can_contain(self):
  701. return True
  702. def encode(self):
  703. """Encode the a description of this object as a 3-tuple of integers"""
  704. return (OBJECT_TO_IDX[self.type], COLOR_TO_IDX["blue"], 0)
  705. def render(self, img):
  706. fill_coords(img, point_in_rect(0, 1, 0, 1), self.color)
  707. class Grid:
  708. """
  709. Represent a grid and operations on it
  710. """
  711. # Static cache of pre-renderer tiles
  712. tile_cache = {}
  713. def __init__(self, width, height):
  714. assert width >= 3
  715. assert height >= 3
  716. self.width = width
  717. self.height = height
  718. self.grid = [None] * width * height
  719. self.background = [None] * width * height
  720. def __contains__(self, key):
  721. if isinstance(key, WorldObj):
  722. for e in self.grid:
  723. if e is key:
  724. return True
  725. elif isinstance(key, tuple):
  726. for e in self.grid:
  727. if e is None:
  728. continue
  729. if (e.color, e.type) == key:
  730. return True
  731. if key[0] is None and key[1] == e.type:
  732. return True
  733. return False
  734. def __eq__(self, other):
  735. grid1 = self.encode()
  736. grid2 = other.encode()
  737. return np.array_equal(grid2, grid1)
  738. def __ne__(self, other):
  739. return not self == other
  740. def copy(self):
  741. return deepcopy(self)
  742. def set(self, i, j, v):
  743. assert i >= 0 and i < self.width
  744. assert j >= 0 and j < self.height
  745. self.grid[j * self.width + i] = v
  746. def get(self, i, j):
  747. assert i >= 0 and i < self.width
  748. assert j >= 0 and j < self.height
  749. return self.grid[j * self.width + i]
  750. def set_background(self, i, j, v):
  751. assert i >= 0 and i < self.width
  752. assert j >= 0 and j < self.height
  753. self.background[j * self.width + i] = v
  754. def get_background(self, i, j):
  755. assert i >= 0 and i < self.width
  756. assert j >= 0 and j < self.height
  757. return self.background[j * self.width + i]
  758. def horz_wall(self, x, y, length=None, obj_type=Wall):
  759. if length is None:
  760. length = self.width - x
  761. for i in range(0, length):
  762. self.set(x + i, y, obj_type())
  763. def vert_wall(self, x, y, length=None, obj_type=Wall):
  764. if length is None:
  765. length = self.height - y
  766. for j in range(0, length):
  767. self.set(x, y + j, obj_type())
  768. def wall_rect(self, x, y, w, h):
  769. self.horz_wall(x, y, w)
  770. self.horz_wall(x, y + h - 1, w)
  771. self.vert_wall(x, y, h)
  772. self.vert_wall(x + w - 1, y, h)
  773. def rotate_left(self):
  774. """
  775. Rotate the grid to the left (counter-clockwise)
  776. """
  777. grid = Grid(self.height, self.width)
  778. for i in range(self.width):
  779. for j in range(self.height):
  780. v = self.get(i, j)
  781. grid.set(j, grid.height - 1 - i, v)
  782. return grid
  783. def slice(self, topX, topY, width, height):
  784. """
  785. Get a subset of the grid
  786. """
  787. grid = Grid(width, height)
  788. for j in range(0, height):
  789. for i in range(0, width):
  790. x = topX + i
  791. y = topY + j
  792. if x >= 0 and x < self.width and y >= 0 and y < self.height:
  793. v = self.get(x, y)
  794. else:
  795. v = Wall()
  796. grid.set(i, j, v)
  797. return grid
  798. @classmethod
  799. def render_tile(
  800. cls, obj, background, agent_dir=None, highlight=False, tile_size=TILE_PIXELS, subdivs=3, carrying=None, cache=True
  801. ):
  802. """
  803. Render a tile and cache the result
  804. """
  805. # Hash map lookup key for the cache
  806. # this prevents re-rendering tiles
  807. key = (agent_dir, highlight, tile_size)
  808. if obj:
  809. key = key + obj.encode()
  810. if background:
  811. key = key + background.encode()
  812. if carrying:
  813. key = key + carrying.encode()
  814. if key in cls.tile_cache and cache:
  815. return cls.tile_cache[key]
  816. img = np.zeros(
  817. shape=(tile_size * subdivs, tile_size * subdivs, 3), dtype=np.uint8
  818. )
  819. # Draw the grid lines (top and left edges)
  820. fill_coords(img, point_in_rect(0, 0.031, 0, 1), (100, 100, 100))
  821. fill_coords(img, point_in_rect(0, 1, 0, 0.031), (100, 100, 100))
  822. if background is not None:
  823. background.render(img)
  824. if obj is not None:
  825. obj.render(img)
  826. # Overlay the agent on top
  827. if agent_dir is not None:
  828. tri_fn = point_in_triangle(
  829. (0.12, 0.19),
  830. (0.87, 0.50),
  831. (0.12, 0.81),
  832. )
  833. # Rotate the agent based on its direction
  834. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  835. fill_coords(img, tri_fn, (255, 0, 0))
  836. if carrying:
  837. width = 0.15
  838. tri_fn = point_in_triangle(
  839. (0.12+width, 0.19+width),
  840. (0.87-width, 0.50),
  841. (0.12+width, 0.81-width),
  842. )
  843. # Rotate the agent based on its direction
  844. tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * agent_dir)
  845. fill_coords(img, tri_fn, (255, 255, 255))
  846. # Highlight the cell if needed
  847. if highlight:
  848. highlight_img(img)
  849. # Downsample the image to perform supersampling/anti-aliasing
  850. img = downsample(img, subdivs)
  851. # Cache the rendered tile
  852. cls.tile_cache[key] = img
  853. return img
  854. def render(self, tile_size, agent_pos, agent_dir=None, highlight_mask=None, carrying=None, cache=True):
  855. """
  856. Render this grid at a given scale
  857. :param r: target renderer object
  858. :param tile_size: tile size in pixels
  859. """
  860. if highlight_mask is None:
  861. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  862. # Compute the total grid size
  863. width_px = self.width * tile_size
  864. height_px = self.height * tile_size
  865. img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8)
  866. # Render the grid
  867. for j in range(0, self.height):
  868. for i in range(0, self.width):
  869. cell = self.get(i, j)
  870. background = self.get_background(i,j)
  871. agent_here = np.array_equal(agent_pos, (i, j))
  872. tile_img = Grid.render_tile(
  873. cell,
  874. background,
  875. agent_dir=agent_dir if agent_here else None,
  876. highlight=highlight_mask[i, j],
  877. tile_size=tile_size,
  878. carrying=carrying if agent_here else None,
  879. cache=cache
  880. )
  881. ymin = j * tile_size
  882. ymax = (j + 1) * tile_size
  883. xmin = i * tile_size
  884. xmax = (i + 1) * tile_size
  885. img[ymin:ymax, xmin:xmax, :] = tile_img
  886. return img
  887. def encode(self, vis_mask=None):
  888. """
  889. Produce a compact numpy encoding of the grid
  890. """
  891. if vis_mask is None:
  892. vis_mask = np.ones((self.width, self.height), dtype=bool)
  893. array = np.zeros((self.width, self.height, 3), dtype="uint8")
  894. for i in range(self.width):
  895. for j in range(self.height):
  896. if vis_mask[i, j]:
  897. v = self.get(i, j)
  898. if v is None:
  899. array[i, j, 0] = OBJECT_TO_IDX["empty"]
  900. array[i, j, 1] = 0
  901. array[i, j, 2] = 0
  902. else:
  903. array[i, j, :] = v.encode()
  904. return array
  905. @staticmethod
  906. def decode(array):
  907. """
  908. Decode an array grid encoding back into a grid
  909. """
  910. width, height, channels = array.shape
  911. assert channels == 3
  912. vis_mask = np.ones(shape=(width, height), dtype=bool)
  913. grid = Grid(width, height)
  914. for i in range(width):
  915. for j in range(height):
  916. type_idx, color_idx, state = array[i, j]
  917. v = WorldObj.decode(type_idx, color_idx, state)
  918. grid.set(i, j, v)
  919. vis_mask[i, j] = type_idx != OBJECT_TO_IDX["unseen"]
  920. return grid, vis_mask
    def process_vis(self, agent_pos):
        """
        Compute which cells of this grid are visible from agent_pos and blank out the rest.

        Visibility starts at agent_pos and is propagated row by row, from the
        last row (highest j) to row 0. Within each row it is swept rightwards
        and then leftwards: a visible cell whose object does not block sight
        (no object, or see_behind() is True) marks its horizontal neighbour
        visible, and, if j > 0, also the diagonal neighbour and the cell in
        row j - 1.

        Note: this mutates the grid -- every cell not marked visible is set to None.

        :param agent_pos: (x, y) of the agent inside this grid
        :return: boolean mask of shape (width, height); True where visible
        """
        mask = np.zeros(shape=(self.width, self.height), dtype=bool)

        mask[agent_pos[0], agent_pos[1]] = True

        for j in reversed(range(0, self.height)):
            # Left-to-right sweep: propagate to the right neighbour (and up).
            for i in range(0, self.width - 1):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i + 1, j] = True
                if j > 0:
                    mask[i + 1, j - 1] = True
                    mask[i, j - 1] = True

            # Right-to-left sweep: propagate to the left neighbour (and up).
            for i in reversed(range(1, self.width)):
                if not mask[i, j]:
                    continue

                cell = self.get(i, j)
                if cell and not cell.see_behind():
                    continue

                mask[i - 1, j] = True
                if j > 0:
                    mask[i - 1, j - 1] = True
                    mask[i, j - 1] = True

        # Erase everything the agent cannot see.
        for j in range(0, self.height):
            for i in range(0, self.width):
                if not mask[i, j]:
                    self.set(i, j, None)

        return mask
  950. def get_neighbours(self, i, j):
  951. neighbours = list()
  952. potential_neighbours = [(i-1,j), (i,j+1), (i+1,j), (i,j-1)]
  953. for n in potential_neighbours:
  954. cell = self.get(*n)
  955. if cell is None or (cell.can_overlap() and not isinstance(cell, Lava)):
  956. neighbours.append(n)
  957. return neighbours
class MiniGridEnv(gym.Env):
    """
    2D grid world game environment.

    Subclasses build the world in _gen_grid(), which reset() calls at the start
    of every episode; step() applies one of the Actions below.
    """

    # Metadata advertised to gym: supported render modes and the declared frame rate.
    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 10,
    }

    # Enumeration of possible actions
    class Actions(IntEnum):
        # Turn left, turn right, move forward
        left = 0
        right = 1
        forward = 2
        # Pick up an object
        pickup = 3
        # Drop an object
        drop = 4
        # Toggle/activate an object
        toggle = 5

        # Done completing task
        done = 6
    def __init__(
        self,
        mission_space: MissionSpace,
        grid_size: Optional[int] = None,
        width: Optional[int] = None,
        height: Optional[int] = None,
        max_steps: int = 100,
        see_through_walls: bool = False,
        agent_view_size: int = 7,
        render_mode: Optional[str] = None,
        highlight: bool = True,
        tile_size: int = TILE_PIXELS,
        agent_pov: bool = False,
    ):
        """
        Configure the environment; the grid itself is generated in reset().

        :param mission_space: space whose sample() provides self.mission
        :param grid_size: side length of a square grid; cannot be combined with width/height
        :param width: grid width in cells (when grid_size is not given)
        :param height: grid height in cells (when grid_size is not given)
        :param max_steps: step() reports truncation once step_count reaches this value
        :param see_through_walls: if True, observations skip occlusion processing
        :param agent_view_size: side of the agent's square view; must be odd and >= 3
        :param render_mode: "human", "rgb_array" or None
        :param highlight: highlight the agent's field of view in full-frame renders
        :param tile_size: tile size in pixels used by render()
        :param agent_pov: render only the agent's point of view
        """
        # Initialize mission
        self.mission = mission_space.sample()
        self.adversaries = list()
        self.background_boxes = dict()
        # Optional flat per-cell reward table; printGrid(init=True) appends it to its output.
        self.bfs_reward = list()
        # Bounded to the 12 most recent entries; cleared on every reset().
        self.violation_deque = deque(maxlen=12)

        # Can't set both grid_size and width/height
        if grid_size:
            assert width is None and height is None
            width = grid_size
            height = grid_size

        # Action enumeration for this environment
        self.actions = MiniGridEnv.Actions

        # Actions are discrete integer values
        self.action_space = spaces.Discrete(len(self.actions))

        # Number of cells (width and height) in the agent view
        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3
        self.agent_view_size = agent_view_size

        # Observations are dictionaries containing an
        # encoding of the grid and a textual 'mission' string
        image_observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, 3),
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {
                "image": image_observation_space,
                "direction": spaces.Discrete(4),
                "mission": mission_space,
            }
        )

        # Range of possible rewards
        # NOTE(review): step() can also grant PICKUPREWARD/DELIVERREWARD
        # (TASKREWARD = 20), which lies outside this declared range -- confirm.
        self.reward_range = (0, 1)

        self.window: Optional[Window] = None

        # Environment configuration
        self.width = width
        self.height = height
        self.max_steps = max_steps
        self.see_through_walls = see_through_walls

        # Current position and direction of the agent (set by reset()/_gen_grid)
        self.agent_pos: Optional[Union[tuple, np.ndarray]] = None
        self.agent_dir: Optional[int] = None

        # Current grid and mission and carrying
        self.grid = Grid(width, height)
        self.colorful_grid = Grid(width, height)
        self.carrying = None

        # Rendering attributes
        self.render_mode = render_mode
        self.highlight = highlight
        self.tile_size = tile_size
        self.agent_pov = agent_pov

        # safety violations: per-episode counts archived by reset(), together
        # with the total_timesteps value at which each episode ended
        self.safety_violations = []
        self.safety_violations_timesteps = []
        self.total_timesteps = 0
        # None until the first reset(), so there is nothing to archive yet
        self.safety_violations_this_episode = None
        self.episode_count = 0
  1054. def reset(self, *, state=None, seed=None, options=None):
  1055. super().reset(seed=seed)
  1056. # Reinitialize episode-specific variables
  1057. if state:
  1058. self.agent_pos = (state.pos_x, state.pos_y)
  1059. self.agent_dir = state.dir
  1060. else:
  1061. self.agent_pos = (-1, -1)
  1062. self.agent_dir = -1
  1063. # Generate a new random grid at the start of each episode
  1064. self._gen_grid(self.width, self.height)
  1065. # These fields should be defined by _gen_grid
  1066. assert (
  1067. self.agent_pos >= (0, 0)
  1068. if isinstance(self.agent_pos, tuple)
  1069. else all(self.agent_pos >= 0) and self.agent_dir >= 0
  1070. )
  1071. # Check that the agent doesn't overlap with an object
  1072. start_cell = self.grid.get(*self.agent_pos)
  1073. assert start_cell is None or start_cell.can_overlap()
  1074. # Item picked up, being carried, initially nothing
  1075. self.carrying = None
  1076. # Step count since episode start
  1077. self.step_count = 0
  1078. if self.render_mode == "human":
  1079. self.render()
  1080. # Return first observation
  1081. obs = self.gen_obs()
  1082. if self.safety_violations_this_episode is not None:
  1083. self.safety_violations.append(self.safety_violations_this_episode)
  1084. self.safety_violations_timesteps.append(self.total_timesteps)
  1085. self.safety_violations_this_episode = 0
  1086. self.episode_count += 1
  1087. #input("Episode End, Hit Enter.")
  1088. self.violation_deque.clear()
  1089. return obs, {}
  1090. def hash(self, size=16):
  1091. """Compute a hash that uniquely identifies the current state of the environment.
  1092. :param size: Size of the hashing
  1093. """
  1094. sample_hash = hashlib.sha256()
  1095. to_encode = [self.grid.encode().tolist(), self.agent_pos, self.agent_dir]
  1096. for item in to_encode:
  1097. sample_hash.update(str(item).encode("utf8"))
  1098. return sample_hash.hexdigest()[:size]
  1099. @property
  1100. def steps_remaining(self):
  1101. return self.max_steps - self.step_count
    def printGrid(self, init=False):
        """
        Produce a pretty string of the environment's grid along with the agent.
        A grid cell is represented by 2-character string, the first one for
        the object and the second one for the color.

        :param init: when True, the grid is regenerated first, the agent cell is
            written as "XR", and a second map (walls, slippery tiles and
            background objects) is built; if self.bfs_reward is set, its
            "i;j;value" entries are appended as a third section.
        :return: the foreground map, a dashed separator line and the background
            map (the background map is empty unless init is True)
        """
        if init:
            self._gen_grid(self.grid.width, self.grid.height) # todo need to add this for minigrid2prism
            print("Dimensions: {} x {}".format(self.grid.height, self.grid.width))
            # NOTE(review): the grid is generated a second time here, consuming
            # more RNG draws -- confirm this is intended.
            self._gen_grid(self.grid.width, self.grid.height)

        # Map of object types to short string
        OBJECT_TO_STR = {
            "wall": "W",
            "floor": "F",
            "door": "D",
            "key": "K",
            "ball": "A",
            "box": "B",
            "goal": "G",
            "lava": "V",
            "adversary": "Z",
            "slipperynorth": "n",
            "slipperysouth": "s",
            "slipperyeast": "e",
            "slipperywest": "w"
        }

        # Map agent's direction to short string
        AGENT_DIR_TO_STR = {0: ">", 1: "V", 2: "<", 3: "^"}

        # NOTE: `str` shadows the builtin inside this method; it accumulates the foreground map.
        str = ""
        bfs_rewards = list()
        background_str = ""

        for j in range(self.grid.height):

            for i in range(self.grid.width):
                b = self.grid.get_background(i, j)
                c = self.grid.get(i, j)

                if init:
                    # Background map: walls and slippery tiles (stored in the
                    # foreground grid) are copied first, then background objects.
                    if c and c.type == "wall":
                        background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
                    elif c and c.type == "slipperynorth":
                        background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
                    elif c and c.type == "slipperysouth":
                        background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
                    elif c and c.type == "slipperyeast":
                        background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
                    elif c and c.type == "slipperywest":
                        background_str += OBJECT_TO_STR[c.type] + c.color[0].upper()
                    elif b is None:
                        # NOTE(review): other cells contribute 2 characters; confirm
                        # this literal has the intended width.
                        background_str += " "
                    else:
                        if b.type != "floor":
                            type_str = OBJECT_TO_STR[b.type]
                        else:
                            type_str = " "
                        # "light" prefixes are stripped so e.g. lightred and red share the letter R.
                        background_str += type_str + b.color.replace("light","")[0].upper()
                    if self.bfs_reward:
                        # bfs_reward is indexed row-major as i + width * j.
                        bfs_rewards.append(f"{i};{j};{self.bfs_reward[i + self.grid.width * j]}")

                if self.agent_pos is not None and i == self.agent_pos[0] and j == self.agent_pos[1]:
                    #str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
                    if init:
                        str += "XR"
                    else:
                        str += 2 * AGENT_DIR_TO_STR[self.agent_dir]
                    continue

                if c is None:
                    #print("{}, {}".format(i,j,), end="")
                    # NOTE(review): other cells contribute 2 characters; confirm
                    # this literal has the intended width.
                    str += " "
                    continue
                #print("{}, {}: {}{}".format(i,j,OBJECT_TO_STR[c.type], c.color[0]), end="")

                if c.type == "door":
                    if c.is_open:
                        str += "__"
                    elif c.is_locked:
                        str += "L" + c.color[0].upper()
                    else:
                        str += "D" + c.color[0].upper()
                    continue

                # Outside init mode, adversaries are shown by their facing direction.
                if not init and c.type == "adversary":
                    str += AGENT_DIR_TO_STR[c.adversary_dir] + c.color[0].upper()
                    continue

                str += OBJECT_TO_STR[c.type] + c.color[0].upper()

            if j < self.grid.height - 1:
                str += "\n"
                if init:
                    background_str += "\n"
            #print("")

        if init and self.bfs_reward:
            return str + "\n" + "-" * self.grid.width * 2 + "\n" + background_str + "\n" + "-" * self.grid.width * 2 + "\n" + ";".join(bfs_rewards)
        else:
            return str + "\n" + "-" * self.grid.width * 2 + "\n" + background_str
  1191. @abstractmethod
  1192. def _gen_grid(self, width, height):
  1193. pass
  1194. def _reward(self):
  1195. """
  1196. Compute the reward to be given upon success
  1197. """
  1198. return 1# - 0.9 * (self.step_count / self.max_steps)
  1199. def _rand_int(self, low, high):
  1200. """
  1201. Generate random integer in [low,high[
  1202. """
  1203. return self.np_random.integers(low, high)
  1204. def _rand_float(self, low, high):
  1205. """
  1206. Generate random float in [low,high[
  1207. """
  1208. return self.np_random.uniform(low, high)
  1209. def _rand_bool(self):
  1210. """
  1211. Generate random boolean value
  1212. """
  1213. return self.np_random.integers(0, 2) == 0
  1214. def _rand_elem(self, iterable):
  1215. """
  1216. Pick a random element in a list
  1217. """
  1218. lst = list(iterable)
  1219. idx = self._rand_int(0, len(lst))
  1220. return lst[idx]
  1221. def _rand_subset(self, iterable, num_elems):
  1222. """
  1223. Sample a random subset of distinct elements of a list
  1224. """
  1225. lst = list(iterable)
  1226. assert num_elems <= len(lst)
  1227. out = []
  1228. while len(out) < num_elems:
  1229. elem = self._rand_elem(lst)
  1230. lst.remove(elem)
  1231. out.append(elem)
  1232. return out
  1233. def _rand_color(self):
  1234. """
  1235. Generate a random color name (string)
  1236. """
  1237. return self._rand_elem(COLOR_NAMES)
  1238. def _rand_pos(self, xLow, xHigh, yLow, yHigh):
  1239. """
  1240. Generate a random (x,y) position tuple
  1241. """
  1242. return (
  1243. self.np_random.integers(xLow, xHigh),
  1244. self.np_random.integers(yLow, yHigh),
  1245. )
  1246. def place_obj(self, obj, top=None, size=None, reject_fn=None, max_tries=math.inf):
  1247. """
  1248. Place an object at an empty position in the grid
  1249. :param top: top-left position of the rectangle where to place
  1250. :param size: size of the rectangle where to place
  1251. :param reject_fn: function to filter out potential positions
  1252. """
  1253. if top is None:
  1254. top = (0, 0)
  1255. else:
  1256. top = (max(top[0], 0), max(top[1], 0))
  1257. if size is None:
  1258. size = (self.grid.width, self.grid.height)
  1259. num_tries = 0
  1260. while True:
  1261. # This is to handle with rare cases where rejection sampling
  1262. # gets stuck in an infinite loop
  1263. if num_tries > max_tries:
  1264. raise RecursionError("rejection sampling failed in place_obj")
  1265. num_tries += 1
  1266. pos = np.array(
  1267. (
  1268. self._rand_int(top[0], min(top[0] + size[0], self.grid.width)),
  1269. self._rand_int(top[1], min(top[1] + size[1], self.grid.height)),
  1270. )
  1271. )
  1272. pos = tuple(pos)
  1273. # Don't place the object on top of another object
  1274. if self.grid.get(*pos) is not None:
  1275. continue
  1276. # Don't place the object where the agent is
  1277. if np.array_equal(pos, self.agent_pos):
  1278. continue
  1279. # Check if there is a filtering criterion
  1280. if reject_fn and reject_fn(self, pos):
  1281. continue
  1282. break
  1283. self.grid.set(pos[0], pos[1], obj)
  1284. if obj is not None:
  1285. obj.init_pos = pos
  1286. obj.cur_pos = pos
  1287. return pos
  1288. def put_obj(self, obj, i, j):
  1289. """
  1290. Put an object at a specific position in the grid
  1291. """
  1292. self.grid.set(i, j, obj)
  1293. obj.init_pos = (i, j)
  1294. obj.cur_pos = (i, j)
  1295. def place_agent(self, top=None, size=None, rand_dir=True, max_tries=math.inf):
  1296. """
  1297. Set the agent's starting point at an empty position in the grid
  1298. """
  1299. self.agent_pos = (-1, -1)
  1300. pos = self.place_obj(None, top, size, max_tries=max_tries)
  1301. self.agent_pos = pos
  1302. if rand_dir:
  1303. self.agent_dir = self._rand_int(0, 4)
  1304. return pos
  1305. def add_slippery_tile(
  1306. self,
  1307. i: int,
  1308. j: int,
  1309. type: str
  1310. ):
  1311. """
  1312. Adds a slippery tile to the grid
  1313. """
  1314. if type=="slipperynorth":
  1315. slippery_tile = SlipperyNorth()
  1316. elif type=="slipperysouth":
  1317. slippery_tile = SlipperySouth()
  1318. elif type=="slipperyeast":
  1319. slippery_tile = SlipperyEast()
  1320. elif type=="slipperywest":
  1321. slippery_tile = SlipperyWest()
  1322. else:
  1323. slippery_tile = SlipperyNorth()
  1324. self.grid.set(i, j, slippery_tile)
  1325. return (i, j)
  1326. #def add_adversary(
  1327. # self,
  1328. # i: int,
  1329. # j: int,
  1330. # color: str,
  1331. # direction: int = 0,
  1332. # tasks: List[Task] = [DoRandom()]
  1333. #):
  1334. # """
  1335. # Adds a slippery tile to the grid
  1336. # """
  1337. # slippery_tile = Slippery()
  1338. # adv = Adversary(direction,color, tasks=tasks)
  1339. # self.put_obj(adv,i,j)
  1340. # self.adversaries[color] = adv
  1341. # return (i, j)
  1342. @property
  1343. def dir_vec(self):
  1344. """
  1345. Get the direction vector for the agent, pointing in the direction
  1346. of forward movement.
  1347. """
  1348. assert self.agent_dir >= 0 and self.agent_dir < 4
  1349. return DIR_TO_VEC[self.agent_dir]
  1350. @property
  1351. def right_vec(self):
  1352. """
  1353. Get the vector pointing to the right of the agent.
  1354. """
  1355. dx, dy = self.dir_vec
  1356. return np.array((-dy, dx))
  1357. @property
  1358. def front_pos(self):
  1359. """
  1360. Get the position of the cell that is right in front of the agent
  1361. """
  1362. return self.agent_pos + self.dir_vec
  1363. def get_view_coords(self, i, j):
  1364. """
  1365. Translate and rotate absolute grid coordinates (i, j) into the
  1366. agent's partially observable view (sub-grid). Note that the resulting
  1367. coordinates may be negative or outside of the agent's view size.
  1368. """
  1369. ax, ay = self.agent_pos
  1370. dx, dy = self.dir_vec
  1371. rx, ry = self.right_vec
  1372. # Compute the absolute coordinates of the top-left view corner
  1373. sz = self.agent_view_size
  1374. hs = self.agent_view_size // 2
  1375. tx = ax + (dx * (sz - 1)) - (rx * hs)
  1376. ty = ay + (dy * (sz - 1)) - (ry * hs)
  1377. lx = i - tx
  1378. ly = j - ty
  1379. # Project the coordinates of the object relative to the top-left
  1380. # corner onto the agent's own coordinate system
  1381. vx = rx * lx + ry * ly
  1382. vy = -(dx * lx + dy * ly)
  1383. return vx, vy
  1384. def get_view_exts(self, agent_view_size=None):
  1385. """
  1386. Get the extents of the square set of tiles visible to the agent
  1387. Note: the bottom extent indices are not included in the set
  1388. if agent_view_size is None, use self.agent_view_size
  1389. """
  1390. agent_view_size = agent_view_size or self.agent_view_size
  1391. # Facing right
  1392. if self.agent_dir == 0:
  1393. topX = self.agent_pos[0]
  1394. topY = self.agent_pos[1] - agent_view_size // 2
  1395. # Facing down
  1396. elif self.agent_dir == 1:
  1397. topX = self.agent_pos[0] - agent_view_size // 2
  1398. topY = self.agent_pos[1]
  1399. # Facing left
  1400. elif self.agent_dir == 2:
  1401. topX = self.agent_pos[0] - agent_view_size + 1
  1402. topY = self.agent_pos[1] - agent_view_size // 2
  1403. # Facing up
  1404. elif self.agent_dir == 3:
  1405. topX = self.agent_pos[0] - agent_view_size // 2
  1406. topY = self.agent_pos[1] - agent_view_size + 1
  1407. else:
  1408. assert False, "invalid agent direction"
  1409. botX = topX + agent_view_size
  1410. botY = topY + agent_view_size
  1411. return (topX, topY, botX, botY)
  1412. def relative_coords(self, x, y):
  1413. """
  1414. Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates
  1415. """
  1416. vx, vy = self.get_view_coords(x, y)
  1417. if vx < 0 or vy < 0 or vx >= self.agent_view_size or vy >= self.agent_view_size:
  1418. return None
  1419. return vx, vy
  1420. def in_view(self, x, y):
  1421. """
  1422. check if a grid position is visible to the agent
  1423. """
  1424. return self.relative_coords(x, y) is not None
  1425. def agent_sees(self, x, y):
  1426. """
  1427. Check if a non-empty grid position is visible to the agent
  1428. """
  1429. coordinates = self.relative_coords(x, y)
  1430. if coordinates is None:
  1431. return False
  1432. vx, vy = coordinates
  1433. obs = self.gen_obs()
  1434. obs_grid, _ = Grid.decode(obs["image"])
  1435. obs_cell = obs_grid.get(vx, vy)
  1436. world_cell = self.grid.get(x, y)
  1437. assert world_cell is not None
  1438. return obs_cell is not None and obs_cell.type == world_cell.type
    def get_neighbours_prob_forward(self, agent_pos, probabilities, offset):
        """
        Build the outcome distribution for moving forward from a slippery tile.

        The 3x3 block around agent_pos is enumerated with x as the outer loop and
        y as the inner loop, and paired entry by entry with `probabilities`.
        Cells holding an object that cannot be overlapped get probability 0.
        Entries equal to -50 act as a placeholder and are skipped when summing.
        The intended target (agent_pos + offset) then receives 1 - sum, unless
        its entry is 0 (e.g. because it is blocked); in that case agent_pos
        (staying in place) receives 1 - sum.

        :param agent_pos: agent position; must be a tuple, as it is used as a dict key
        :param probabilities: one value per neighbour, in the order described above
        :param offset: (dx, dy) of the intended forward move
        :return: (positions, probabilities) as two parallel lists
        """
        neighbours = [tuple((x,y)) for x in range(agent_pos[0]-1, agent_pos[0]+2) for y in range(agent_pos[1]-1,agent_pos[1]+2)]
        probabilities_dict = dict(zip(neighbours, probabilities))

        # Cells that cannot be entered get no probability mass.
        for pos in probabilities_dict:
            cell = self.grid.get(*pos)
            if cell is not None and not cell.can_overlap():
                probabilities_dict[pos] = 0.0

        # Sum the explicit probabilities, skipping the -50 placeholder.
        sum_prob = 0
        for pos in probabilities_dict:
            if (probabilities_dict[pos]!=-50):
                sum_prob += probabilities_dict[pos]

        # NOTE(review): sum_prob includes agent_pos's own entry, which is then
        # overwritten below; and a -50 placeholder located anywhere other than
        # the target cell survives into the returned list. Confirm the
        # probability tables never hit these cases.
        if probabilities_dict[tuple((agent_pos[0] + offset[0], agent_pos[1]+offset[1]))] == 0:
            probabilities_dict[agent_pos] = 1-sum_prob
        else:
            probabilities_dict[tuple((agent_pos[0] + offset[0], agent_pos[1]+offset[1]))] = 1-sum_prob
            probabilities_dict[agent_pos] = 0.0

        #print(probabilities_dict)
        #print(agent_pos+offset)
        return list(probabilities_dict.keys()), list(probabilities_dict.values())
  1458. def get_neighbours_prob_turn(self, agent_pos, probabilities):
  1459. neighbours = [tuple((x,y)) for x in range(agent_pos[0]-1, agent_pos[0]+2) for y in range(agent_pos[1]-1,agent_pos[1]+2)]
  1460. non_blocked_neighbours = []
  1461. i = 0
  1462. non_blocked_prob = []
  1463. for pos in neighbours:
  1464. cell = self.grid.get(*pos)
  1465. if (cell is None or cell.can_overlap()):
  1466. non_blocked_neighbours.append(pos)
  1467. non_blocked_prob.append(probabilities[i])
  1468. i += 1
  1469. sum_prob = 0
  1470. for prob in non_blocked_prob:
  1471. if (prob!=-50):
  1472. sum_prob += prob
  1473. non_blocked_prob = [x if x!=-50 else 1-sum_prob for x in non_blocked_prob]
  1474. return non_blocked_neighbours, non_blocked_prob
  1475. def step(self, action):
  1476. self.step_count += 1
  1477. self.total_timesteps += 1
  1478. reward = 0
  1479. terminated = False
  1480. truncated = False
  1481. # Get the position in front of the agent
  1482. fwd_pos = self.front_pos
  1483. # Get the contents of the cell in front of the agent
  1484. fwd_cell = self.grid.get(*fwd_pos)
  1485. current_cell = self.grid.get(*self.agent_pos)
  1486. ran_into_lava = False
  1487. reached_goal = False
  1488. if action == self.actions.forward and isSlippery(current_cell):
  1489. possible_fwd_pos, prob = self.get_neighbours_prob_forward(self.agent_pos, current_cell.probabilities_forward, current_cell.offset)
  1490. fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
  1491. fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
  1492. fwd_cell = self.grid.get(*fwd_pos)
  1493. # Rotate left
  1494. if action == self.actions.left:
  1495. self.agent_dir -= 1
  1496. if self.agent_dir < 0:
  1497. self.agent_dir += 4
  1498. if isSlippery(current_cell):
  1499. possible_fwd_pos, prob = self.get_neighbours_prob_turn(self.agent_pos, current_cell.probabilities_turn)
  1500. fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
  1501. fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
  1502. fwd_cell = self.grid.get(*fwd_pos)
  1503. if fwd_cell is None or fwd_cell.can_overlap():
  1504. self.agent_pos = tuple(fwd_pos)
  1505. if fwd_cell is not None and fwd_cell.type == "goal":
  1506. terminated = True
  1507. reached_goal = True
  1508. reward = self._reward()
  1509. if fwd_cell is not None and fwd_cell.type == "lava":
  1510. terminated = True
  1511. ran_into_lava = True
  1512. # Rotate right
  1513. elif action == self.actions.right:
  1514. self.agent_dir = (self.agent_dir + 1) % 4
  1515. if isSlippery(current_cell):
  1516. possible_fwd_pos, prob = self.get_neighbours_prob_turn(self.agent_pos, current_cell.probabilities_turn)
  1517. fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
  1518. fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
  1519. fwd_cell = self.grid.get(*fwd_pos)
  1520. if fwd_cell is None or fwd_cell.can_overlap():
  1521. self.agent_pos = tuple(fwd_pos)
  1522. if fwd_cell is not None and fwd_cell.type == "goal":
  1523. terminated = True
  1524. reached_goal = True
  1525. reward = self._reward()
  1526. if fwd_cell is not None and fwd_cell.type == "lava":
  1527. terminated = True
  1528. ran_into_lava = True
  1529. # move forward
  1530. elif action == self.actions.forward:
  1531. if fwd_cell is None or fwd_cell.can_overlap():
  1532. self.agent_pos = tuple(fwd_pos)
  1533. if fwd_cell is not None and fwd_cell.type == "goal":
  1534. terminated = True
  1535. reached_goal = True
  1536. reward = self._reward()
  1537. if fwd_cell is not None and fwd_cell.type == "lava":
  1538. terminated = True
  1539. ran_into_lava = True
  1540. # Pick up an object
  1541. elif action == self.actions.pickup:
  1542. if fwd_cell and fwd_cell.can_pickup():
  1543. if self.carrying is None:
  1544. if type(fwd_cell) == Box and fwd_cell.color == "red":
  1545. if not fwd_cell.picked_up:
  1546. self.reward += PICKUPREWARD
  1547. fwd_cell.picked_up = True
  1548. self.carrying = fwd_cell
  1549. self.carrying.cur_pos = np.array([-1, -1])
  1550. self.grid.set(fwd_pos[0], fwd_pos[1], None)
  1551. # Drop an object
  1552. elif action == self.actions.drop:
  1553. if not fwd_cell and self.carrying:
  1554. if type(self.carrying == Box) and self.carrying.color == "red" and self.grid.get_background(*fwd_pos) and self.grid.get_background(*fwd_pos).color == "red":
  1555. self.reward += DELIVERREWARD
  1556. terminated = True
  1557. self.carrying.placed_at_destination = True
  1558. self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
  1559. self.carrying.cur_pos = fwd_pos
  1560. self.carrying = None
  1561. # Toggle/activate an object
  1562. elif action == self.actions.toggle:
  1563. if fwd_cell:
  1564. fwd_cell.toggle(self, fwd_pos)
  1565. # Done action (not used by default)
  1566. elif action == self.actions.done:
  1567. pass
  1568. else:
  1569. raise ValueError(f"Unknown action: {action}")
  1570. if self.step_count >= self.max_steps:
  1571. truncated = True
  1572. if self.render_mode == "human":
  1573. self.render()
  1574. obs = self.gen_obs()
  1575. return obs, reward, terminated, truncated, {"pos": self.agent_pos, "dir": self.agent_dir, "ran_into_lava": ran_into_lava, "reached_goal": reached_goal, "is_success": reached_goal}
  1576. def gen_obs_grid(self, agent_view_size=None):
  1577. """
  1578. Generate the sub-grid observed by the agent.
  1579. This method also outputs a visibility mask telling us which grid
  1580. cells the agent can actually see.
  1581. if agent_view_size is None, self.agent_view_size is used
  1582. """
  1583. topX, topY, botX, botY = self.get_view_exts(agent_view_size)
  1584. agent_view_size = agent_view_size or self.agent_view_size
  1585. grid = self.grid.slice(topX, topY, agent_view_size, agent_view_size)
  1586. for i in range(self.agent_dir + 1):
  1587. grid = grid.rotate_left()
  1588. # Process occluders and visibility
  1589. # Note that this incurs some performance cost
  1590. if not self.see_through_walls:
  1591. vis_mask = grid.process_vis(
  1592. agent_pos=(agent_view_size // 2, agent_view_size - 1)
  1593. )
  1594. else:
  1595. vis_mask = np.ones(shape=(grid.width, grid.height), dtype=bool)
  1596. # Make it so the agent sees what it's carrying
  1597. # We do this by placing the carried object at the agent's position
  1598. # in the agent's partially observable view
  1599. agent_pos = grid.width // 2, grid.height - 1
  1600. if self.carrying:
  1601. grid.set(*agent_pos, self.carrying)
  1602. else:
  1603. grid.set(*agent_pos, None)
  1604. return grid, vis_mask
  1605. def gen_obs(self):
  1606. """
  1607. Generate the agent's view (partially observable, low-resolution encoding)
  1608. """
  1609. grid, vis_mask = self.gen_obs_grid()
  1610. # Encode the partially observable view into a numpy array
  1611. image = grid.encode(vis_mask)
  1612. # Observations are dictionaries containing:
  1613. # - an image (partially observable view of the environment)
  1614. # - the agent's direction/orientation (acting as a compass)
  1615. # - a textual mission string (instructions for the agent)
  1616. obs = {"image": image, "direction": self.agent_dir, "mission": self.mission}
  1617. return obs
  1618. def get_pov_render(self, tile_size):
  1619. """
  1620. Render an agent's POV observation for visualization
  1621. """
  1622. grid, vis_mask = self.gen_obs_grid()
  1623. # Render the whole grid
  1624. img = grid.render(
  1625. tile_size,
  1626. agent_pos=(self.agent_view_size // 2, self.agent_view_size - 1),
  1627. agent_dir=3,
  1628. highlight_mask=vis_mask,
  1629. carrying=self.carrying,
  1630. )
  1631. return img
  1632. def get_full_render(self, highlight, tile_size):
  1633. """
  1634. Render a non-paratial observation for visualization
  1635. """
  1636. # Compute which cells are visible to the agent
  1637. _, vis_mask = self.gen_obs_grid()
  1638. # Compute the world coordinates of the bottom-left corner
  1639. # of the agent's view area
  1640. f_vec = self.dir_vec
  1641. r_vec = self.right_vec
  1642. top_left = (
  1643. self.agent_pos
  1644. + f_vec * (self.agent_view_size - 1)
  1645. - r_vec * (self.agent_view_size // 2)
  1646. )
  1647. # Mask of which cells to highlight
  1648. highlight_mask = np.zeros(shape=(self.width, self.height), dtype=bool)
  1649. # For each cell in the visibility mask
  1650. for vis_j in range(0, self.agent_view_size):
  1651. for vis_i in range(0, self.agent_view_size):
  1652. # If this cell is not visible, don't highlight it
  1653. if not vis_mask[vis_i, vis_j]:
  1654. continue
  1655. # Compute the world coordinates of this cell
  1656. abs_i, abs_j = top_left - (f_vec * vis_j) + (r_vec * vis_i)
  1657. if abs_i < 0 or abs_i >= self.width:
  1658. continue
  1659. if abs_j < 0 or abs_j >= self.height:
  1660. continue
  1661. # Mark this cell to be highlighted
  1662. highlight_mask[abs_i, abs_j] = True
  1663. # Render the whole grid
  1664. img = self.grid.render(
  1665. tile_size,
  1666. self.agent_pos,
  1667. self.agent_dir,
  1668. highlight_mask=highlight_mask if highlight else None,
  1669. carrying=self.carrying,
  1670. )
  1671. return img
  1672. def get_frame(
  1673. self,
  1674. highlight: bool = True,
  1675. tile_size: int = TILE_PIXELS,
  1676. agent_pov: bool = False,
  1677. ):
  1678. """Returns an RGB image corresponding to the whole environment or the agent's point of view.
  1679. Args:
  1680. highlight (bool): If true, the agent's field of view or point of view is highlighted with a lighter gray color.
  1681. tile_size (int): How many pixels will form a tile from the NxM grid.
  1682. agent_pov (bool): If true, the rendered frame will only contain the point of view of the agent.
  1683. Returns:
  1684. frame (np.ndarray): A frame of type numpy.ndarray with shape (x, y, 3) representing RGB values for the x-by-y pixel image.
  1685. """
  1686. if agent_pov:
  1687. return self.get_pov_render(tile_size)
  1688. else:
  1689. return self.get_full_render(highlight, tile_size)
  1690. def render(self, mode=""):
  1691. img = self.get_frame(self.highlight, self.tile_size, self.agent_pov)
  1692. if mode == "human":
  1693. if self.window is None:
  1694. self.window = Window("gym_minigrid")
  1695. self.window.show(block=False)
  1696. self.window.set_caption(self.mission)
  1697. self.window.show_img(img)
  1698. elif mode == "rgb_array":
  1699. return img
    def close(self):
        """Close the rendering window, if render() opened one in human mode."""
        if self.window:
            self.window.close()