|
@ -8,7 +8,7 @@ from gymnasium.spaces import Dict, Box |
|
|
from collections import deque |
|
|
from collections import deque |
|
|
from ray.rllib.utils.numpy import one_hot |
|
|
from ray.rllib.utils.numpy import one_hot |
|
|
|
|
|
|
|
|
from helpers import get_action_index_mapping, extract_keys |
|
|
|
|
|
|
|
|
from helpers import get_action_index_mapping |
|
|
from shieldhandlers import ShieldHandler |
|
|
from shieldhandlers import ShieldHandler |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -67,7 +67,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
|
|
|
|
|
|
|
image = obs["data"] |
|
|
image = obs["data"] |
|
|
|
|
|
|
|
|
# One-hot the last dim into 12, 6, 3 one-hot vectors, then flatten. |
|
|
|
|
|
|
|
|
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. |
|
|
objects = one_hot(image[:, :, 0], depth=16) |
|
|
objects = one_hot(image[:, :, 0], depth=16) |
|
|
colors = one_hot(image[:, :, 1], depth=6) |
|
|
colors = one_hot(image[:, :, 1], depth=6) |
|
|
states = one_hot(image[:, :, 2], depth=3) |
|
|
states = one_hot(image[:, :, 2], depth=3) |
|
@ -108,7 +108,9 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
return np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|
|
return np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|
|
|
|
|
|
|
|
cur_pos_str = self.shield_query_creator(self.env) |
|
|
cur_pos_str = self.shield_query_creator(self.env) |
|
|
|
|
|
|
|
|
|
|
|
# print(F"Pos string {cur_pos_str}") |
|
|
|
|
|
# print(F"Shield {list(self.shield.keys())[0]}") |
|
|
|
|
|
# print(cur_pos_str in self.shield) |
|
|
# Create the mask |
|
|
# Create the mask |
|
|
# If shield restricts action mask only valid with 1.0 |
|
|
# If shield restricts action mask only valid with 1.0 |
|
|
# else set all actions as valid |
|
|
# else set all actions as valid |
|
@ -120,6 +122,8 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
for allowed_action in allowed_actions: |
|
|
for allowed_action in allowed_actions: |
|
|
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set |
|
|
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set |
|
|
if index is None: |
|
|
if index is None: |
|
|
|
|
|
print(F"No mapping for action {list(allowed_action.labels)}") |
|
|
|
|
|
print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}") |
|
|
assert(False) |
|
|
assert(False) |
|
|
|
|
|
|
|
|
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
|
|
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
|
@ -134,8 +138,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
if front_tile is not None and front_tile.type == "key": |
|
|
if front_tile is not None and front_tile.type == "key": |
|
|
mask[Actions.pickup] = 1.0 |
|
|
mask[Actions.pickup] = 1.0 |
|
|
|
|
|
|
|
|
# if self.env.carrying: |
|
|
|
|
|
# mask[Actions.drop] = 1.0 |
|
|
|
|
|
|
|
|
|
|
|
if front_tile and front_tile.type == "door": |
|
|
if front_tile and front_tile.type == "door": |
|
|
mask[Actions.toggle] = 1.0 |
|
|
mask[Actions.toggle] = 1.0 |
|
@ -148,7 +150,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
if self.create_shield_at_reset and self.mask_actions: |
|
|
if self.create_shield_at_reset and self.mask_actions: |
|
|
self.shield = self.shield_creator.create_shield(env=self.env) |
|
|
self.shield = self.shield_creator.create_shield(env=self.env) |
|
|
|
|
|
|
|
|
self.keys = extract_keys(self.env) |
|
|
|
|
|
mask = self.create_action_mask() |
|
|
mask = self.create_action_mask() |
|
|
return { |
|
|
return { |
|
|
"data": obs["image"], |
|
|
"data": obs["image"], |
|
@ -164,7 +165,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
"action_mask": mask, |
|
|
"action_mask": mask, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
#print(F"Info is {info}") |
|
|
|
|
|
return obs, rew, done, truncated, info |
|
|
return obs, rew, done, truncated, info |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -222,10 +222,8 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
def reset(self, *, seed=None, options=None): |
|
|
def reset(self, *, seed=None, options=None): |
|
|
obs, infos = self.env.reset(seed=seed, options=options) |
|
|
obs, infos = self.env.reset(seed=seed, options=options) |
|
|
|
|
|
|
|
|
keys = extract_keys(self.env) |
|
|
|
|
|
shield = self.shield_creator.create_shield(env=self.env) |
|
|
shield = self.shield_creator.create_shield(env=self.env) |
|
|
|
|
|
|
|
|
self.keys = keys |
|
|
|
|
|
self.shield = shield |
|
|
self.shield = shield |
|
|
return obs["image"], infos |
|
|
return obs["image"], infos |
|
|
|
|
|
|
|
|