From 574511317901e9950eab4ed13c432a60f3748f79 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Fri, 22 Sep 2023 15:24:24 +0200 Subject: [PATCH] changed one hot wrapping --- examples/shields/rl/wrappers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py index 74d0c7e..4f1d94d 100644 --- a/examples/shields/rl/wrappers.py +++ b/examples/shields/rl/wrappers.py @@ -3,6 +3,7 @@ import numpy as np import random from minigrid.core.actions import Actions +from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX from gymnasium.spaces import Dict, Box from collections import deque @@ -18,7 +19,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): self.framestack = framestack # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. # +4: Direction. - self.single_frame_dim = 49 * (16 + 6 + 3) + 4 + self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 self.init_x = None self.init_y = None self.x_positions = [] @@ -66,11 +67,10 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): self.y_positions.append(self.agent_pos[1]) image = obs["data"] - # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. - objects = one_hot(image[:, :, 0], depth=16) - colors = one_hot(image[:, :, 1], depth=6) - states = one_hot(image[:, :, 2], depth=3) + objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) + colors = one_hot(image[:, :, 1], depth=len(COLORS)) + states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) all_ = np.concatenate([objects, colors, states], -1) all_flat = np.reshape(all_, (-1,))