|
@ -3,6 +3,7 @@ import numpy as np |
|
|
import random |
|
|
import random |
|
|
|
|
|
|
|
|
from minigrid.core.actions import Actions |
|
|
from minigrid.core.actions import Actions |
|
|
|
|
|
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX |
|
|
|
|
|
|
|
|
from gymnasium.spaces import Dict, Box |
|
|
from gymnasium.spaces import Dict, Box |
|
|
from collections import deque |
|
|
from collections import deque |
|
@ -18,7 +19,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
|
self.framestack = framestack |
|
|
self.framestack = framestack |
|
|
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. |
|
|
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. |
|
|
# +4: Direction. |
|
|
# +4: Direction. |
|
|
self.single_frame_dim = 49 * (16 + 6 + 3) + 4 |
|
|
|
|
|
|
|
|
self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 |
|
|
self.init_x = None |
|
|
self.init_x = None |
|
|
self.init_y = None |
|
|
self.init_y = None |
|
|
self.x_positions = [] |
|
|
self.x_positions = [] |
|
@ -66,11 +67,10 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
|
self.y_positions.append(self.agent_pos[1]) |
|
|
self.y_positions.append(self.agent_pos[1]) |
|
|
|
|
|
|
|
|
image = obs["data"] |
|
|
image = obs["data"] |
|
|
|
|
|
|
|
|
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. |
|
|
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. |
|
|
objects = one_hot(image[:, :, 0], depth=16) |
|
|
|
|
|
colors = one_hot(image[:, :, 1], depth=6) |
|
|
|
|
|
states = one_hot(image[:, :, 2], depth=3) |
|
|
|
|
|
|
|
|
objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) |
|
|
|
|
|
colors = one_hot(image[:, :, 1], depth=len(COLORS)) |
|
|
|
|
|
states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) |
|
|
|
|
|
|
|
|
all_ = np.concatenate([objects, colors, states], -1) |
|
|
all_ = np.concatenate([objects, colors, states], -1) |
|
|
all_flat = np.reshape(all_, (-1,)) |
|
|
all_flat = np.reshape(all_, (-1,)) |
|
|