Browse Source

changed one hot wrapping

refactoring
Thomas Knoll 1 year ago
parent
commit
5745113179
  1. 10
      examples/shields/rl/wrappers.py

10
examples/shields/rl/wrappers.py

@ -3,6 +3,7 @@ import numpy as np
import random
from minigrid.core.actions import Actions
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
from gymnasium.spaces import Dict, Box
from collections import deque
@ -18,7 +19,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
self.framestack = framestack
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
# +4: Direction.
self.single_frame_dim = 49 * (16 + 6 + 3) + 4
self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
self.init_x = None
self.init_y = None
self.x_positions = []
@ -66,11 +67,10 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
self.y_positions.append(self.agent_pos[1])
image = obs["data"]
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
objects = one_hot(image[:, :, 0], depth=16)
colors = one_hot(image[:, :, 1], depth=6)
states = one_hot(image[:, :, 2], depth=3)
objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
colors = one_hot(image[:, :, 1], depth=len(COLORS))
states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
all_ = np.concatenate([objects, colors, states], -1)
all_flat = np.reshape(all_, (-1,))

Loading…
Cancel
Save