|
@ -1,5 +1,6 @@ |
|
|
import gymnasium as gym |
|
|
import gymnasium as gym |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
import random |
|
|
|
|
|
|
|
|
from minigrid.core.actions import Actions |
|
|
from minigrid.core.actions import Actions |
|
|
|
|
|
|
|
@ -15,9 +16,9 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
|
def __init__(self, env, vector_index, framestack): |
|
|
def __init__(self, env, vector_index, framestack): |
|
|
super().__init__(env) |
|
|
super().__init__(env) |
|
|
self.framestack = framestack |
|
|
self.framestack = framestack |
|
|
# 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. |
|
|
|
|
|
|
|
|
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. |
|
|
# +4: Direction. |
|
|
# +4: Direction. |
|
|
self.single_frame_dim = 49 * (11 + 6 + 3) + 4 |
|
|
|
|
|
|
|
|
self.single_frame_dim = 49 * (16 + 6 + 3) + 4 |
|
|
self.init_x = None |
|
|
self.init_x = None |
|
|
self.init_y = None |
|
|
self.init_y = None |
|
|
self.x_positions = [] |
|
|
self.x_positions = [] |
|
@ -66,8 +67,8 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
|
|
|
|
|
|
|
image = obs["data"] |
|
|
image = obs["data"] |
|
|
|
|
|
|
|
|
# One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. |
|
|
|
|
|
objects = one_hot(image[:, :, 0], depth=11) |
|
|
|
|
|
|
|
|
# One-hot the last dim into 12, 6, 3 one-hot vectors, then flatten. |
|
|
|
|
|
objects = one_hot(image[:, :, 0], depth=16) |
|
|
colors = one_hot(image[:, :, 1], depth=6) |
|
|
colors = one_hot(image[:, :, 1], depth=6) |
|
|
states = one_hot(image[:, :, 2], depth=3) |
|
|
states = one_hot(image[:, :, 2], depth=3) |
|
|
|
|
|
|
|
@ -117,10 +118,13 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|
|
allowed_actions = self.shield[cur_pos_str] |
|
|
allowed_actions = self.shield[cur_pos_str] |
|
|
for allowed_action in allowed_actions: |
|
|
for allowed_action in allowed_actions: |
|
|
index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set |
|
|
|
|
|
|
|
|
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set |
|
|
if index is None: |
|
|
if index is None: |
|
|
assert(False) |
|
|
assert(False) |
|
|
mask[index] = 1.0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
|
|
|
|
|
mask[index] = allowed |
|
|
|
|
|
|
|
|
else: |
|
|
else: |
|
|
for index, x in enumerate(mask): |
|
|
for index, x in enumerate(mask): |
|
|
mask[index] = 1.0 |
|
|
mask[index] = 1.0 |
|
@ -197,11 +201,11 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|
|
allowed_actions = self.shield[cur_pos_str] |
|
|
allowed_actions = self.shield[cur_pos_str] |
|
|
for allowed_action in allowed_actions: |
|
|
for allowed_action in allowed_actions: |
|
|
index = get_action_index_mapping(allowed_action[1]) |
|
|
|
|
|
|
|
|
index = get_action_index_mapping(allowed_action.labels) |
|
|
if index is None: |
|
|
if index is None: |
|
|
assert(False) |
|
|
assert(False) |
|
|
|
|
|
|
|
|
mask[index] = 1.0 |
|
|
|
|
|
|
|
|
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
|
|
else: |
|
|
else: |
|
|
for index, x in enumerate(mask): |
|
|
for index, x in enumerate(mask): |
|
|
mask[index] = 1.0 |
|
|
mask[index] = 1.0 |
|
|