import random
from collections import deque

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Dict, Box
from minigrid.core.actions import Actions
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
from ray.rllib.utils.numpy import one_hot

from helpers import get_action_index_mapping
from shieldhandlers import ShieldHandler


class OneHotShieldingWrapper(gym.core.ObservationWrapper):
    """One-hot encode the image observation and stack the last frames.

    Expects the wrapped env to emit dict observations with the keys
    ``"data"`` (the encoded image) and ``"action_mask"``. The mask is passed
    through unchanged; the image is one-hot encoded per channel, flattened,
    extended by the one-hot agent direction and concatenated with the
    previous ``framestack - 1`` frames.
    """

    def __init__(self, env, vector_index, framestack):
        """Set up the frame buffer and the flattened observation space.

        Args:
            env: Environment producing ``{"data", "action_mask"}`` dicts.
            vector_index: Index of this env inside a vectorized set; only
                index 0 prints the exploration debug statistic.
            framestack: Number of consecutive frames concatenated into one
                observation.
        """
        super().__init__(env)
        self.framestack = framestack
        # 49 = 7x7 field of vision; every cell carries one one-hot vector for
        # the object type, one for the color and one for the state.
        # +4: one-hot agent direction.
        self.single_frame_dim = (
            49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
        )
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._fill_with_empty_frames()

        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(
                    0.0,
                    1.0,
                    shape=(self.single_frame_dim * self.framestack,),
                    dtype=np.float32,
                ),
                "action_mask": gym.spaces.Box(
                    0, 10, shape=(env.action_space.n,), dtype=int
                ),
            }
        )

    def _fill_with_empty_frames(self):
        """Overwrite the whole frame buffer with all-zero frames."""
        for _ in range(self.framestack):
            # float32 so the stacked result matches the declared "data" space
            # (np.zeros defaults to float64 and would upcast the whole stack).
            self.frame_buffer.append(
                np.zeros((self.single_frame_dim,), dtype=np.float32)
            )

    def observation(self, obs):
        """Return the stacked one-hot observation for ``obs``.

        Args:
            obs: Dict with ``"data"`` (image whose last axis holds the object,
                color and state indices) and ``"action_mask"``.

        Returns:
            Dict with ``"data"`` (concatenation of the buffered frames) and
            the unchanged ``"action_mask"``.
        """
        track_positions = self.vector_index == 0

        if self.step_count == 0:
            # New episode: forget the frames of the previous one.
            self._fill_with_empty_frames()
            if track_positions:
                # Debug output: max-x/y positions to watch exploration
                # progress.
                if self.x_positions:
                    max_diff = np.max(
                        np.sqrt(
                            (np.array(self.x_positions) - self.init_x) ** 2
                            + (np.array(self.y_positions) - self.init_y) ** 2
                        )
                    )
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = self.agent_pos[0]
                self.init_y = self.agent_pos[1]

        # Positions are only consumed (and cleared) by the debug statistic of
        # vector index 0; recording them for other indices would grow the
        # lists without bound.
        if track_positions:
            self.x_positions.append(self.agent_pos[0])
            self.y_positions.append(self.agent_pos[1])

        image = obs["data"]

        # One-hot each channel of the last dim separately, then flatten.
        objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
        colors = one_hot(image[:, :, 1], depth=len(COLORS))
        states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))

        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)

        return {
            "data": np.concatenate(list(self.frame_buffer)),
            "action_mask": obs["action_mask"],
        }


class MiniGridShieldingWrapper(gym.core.Wrapper):
    """Attach a shield-derived action mask to every MiniGrid observation.

    Observations become ``{"data": <image>, "action_mask": <int8 vector>}``
    where a mask entry of 1 marks an action as allowed.
    """

    def __init__(
        self,
        env,
        shield_creator: ShieldHandler,
        shield_query_creator,
        create_shield_at_reset=True,
        mask_actions=True,
    ):
        """Create the wrapper and build an initial shield.

        Args:
            env: Environment whose observations contain an ``"image"`` entry.
            shield_creator: Handler whose ``create_shield(env=...)`` builds
                the shield lookup.
            shield_query_creator: Callable mapping the env to the key used to
                query the shield for the current state.
            create_shield_at_reset: Rebuild the shield on every ``reset``
                (only done while ``mask_actions`` is true).
            mask_actions: If false, every action is always allowed.
        """
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "action_mask": Box(
                    0, 10, shape=(self.max_available_actions,), dtype=np.int8
                ),
            }
        )
        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        self.shield = shield_creator.create_shield(env=self.env)
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        print(f"Shielding is {self.mask_actions}")

    def create_action_mask(self):
        """Return the int8 action mask for the current environment state.

        If the shield has a non-empty entry for the current state, only the
        actions listed there are allowed; otherwise all actions are allowed.
        Picking up a key and toggling a door directly in front of the agent
        are always allowed.

        Raises:
            AssertionError: If a shield action has no index mapping or is
                listed with probability 0.
        """
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        cur_pos_str = self.shield_query_creator(self.env)
        mask = np.zeros(self.max_available_actions, dtype=np.int8)

        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
            for allowed_action in self.shield[cur_pos_str]:
                # allowed_action.labels is a set of action labels.
                index = get_action_index_mapping(allowed_action.labels)
                # Raised explicitly (instead of `assert`) so the checks are
                # not stripped when Python runs with -O.
                if index is None:
                    print(F"No mapping for action {list(allowed_action.labels)}")
                    print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}")
                    raise AssertionError(
                        f"No mapping for action {list(allowed_action.labels)}"
                    )
                if allowed_action.prob == 0:
                    raise AssertionError(
                        f"Shield lists an action with probability 0 at {cur_pos_str}"
                    )
                mask[index] = 1
        else:
            # The shield does not restrict this state: everything is valid.
            mask[:] = 1

        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile is not None and front_tile.type == "key":
            mask[Actions.pickup] = 1
        if front_tile is not None and front_tile.type == "door":
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the env, optionally rebuild the shield, and add the mask."""
        obs, infos = self.env.reset(seed=seed, options=options)

        if self.create_shield_at_reset and self.mask_actions:
            self.shield = self.shield_creator.create_shield(env=self.env)

        mask = self.create_action_mask()
        return {"data": obs["image"], "action_mask": mask}, infos

    def step(self, action):
        """Step the env and return the image together with the new mask."""
        orig_obs, rew, done, truncated, info = self.env.step(action)

        obs = {
            "data": orig_obs["image"],
            "action_mask": self.create_action_mask(),
        }
        return obs, rew, done, truncated, info


class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Shielding wrapper that exposes the plain image as observation.

    The action mask is not part of the observation; callers obtain it through
    ``create_action_mask``.
    """

    def __init__(
        self,
        env,
        shield_creator: ShieldHandler,
        shield_query_creator,
        create_shield_at_reset=True,
        mask_actions=True,
    ):
        """Create the wrapper; the shield itself is built lazily.

        Args:
            env: Environment whose observations contain an ``"image"`` entry.
            shield_creator: Handler whose ``create_shield(env=...)`` builds
                the shield lookup.
            shield_query_creator: Callable mapping the env to the key used to
                query the shield for the current state.
            create_shield_at_reset: Rebuild the shield on every ``reset``.
                If false, the shield is built once and then reused.
            mask_actions: If false, every action is always allowed.
        """
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]
        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        # Built on the first reset (or first mask query, whichever is first).
        self.shield = None
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        print(f"Shielding is {self.mask_actions}")

    def create_action_mask(self):
        """Return the int8 action mask for the current environment state.

        Each action listed by the shield for the current state is enabled by
        a random draw weighted with the action's ``prob``. States without a
        (non-empty) shield entry allow every action. Toggling a door directly
        in front of the agent is always allowed.

        Raises:
            AssertionError: If a shield action has no index mapping.
        """
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        if self.shield is None:
            # Queried before the first reset: build the shield on demand
            # instead of failing on a missing attribute.
            self.shield = self.shield_creator.create_shield(env=self.env)

        cur_pos_str = self.shield_query_creator(self.env)
        mask = np.zeros(self.max_available_actions, dtype=np.int8)

        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
            for allowed_action in self.shield[cur_pos_str]:
                index = get_action_index_mapping(allowed_action.labels)
                if index is None:
                    # Raised explicitly (instead of `assert`) so the check is
                    # not stripped when Python runs with -O.
                    raise AssertionError(
                        f"No mapping for action {list(allowed_action.labels)}"
                    )
                mask[index] = random.choices(
                    [0.0, 1.0],
                    weights=(1 - allowed_action.prob, allowed_action.prob),
                )[0]
        else:
            # The shield does not restrict this state: everything is valid.
            mask[:] = 1

        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile is not None and front_tile.type == "door":
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the env, (re)build the shield and return the image."""
        obs, infos = self.env.reset(seed=seed, options=options)

        if self.create_shield_at_reset or self.shield is None:
            self.shield = self.shield_creator.create_shield(env=self.env)

        return obs["image"], infos

    def step(self, action):
        """Step the env and return only the image part of the observation."""
        orig_obs, rew, done, truncated, info = self.env.step(action)
        return orig_obs["image"], rew, done, truncated, info