Thomas Knoll
11 months ago
8 changed files with 420 additions and 182 deletions
-
46examples/shields/rl/11_minigridrl.py
-
33examples/shields/rl/12_minigridrl_tune.py
-
23examples/shields/rl/13_minigridsb.py
-
34examples/shields/rl/14_train_eval.py
-
65examples/shields/rl/15_train_eval_tune.py
-
209examples/shields/rl/rllibutils.py
-
68examples/shields/rl/sb3utils.py
-
124examples/shields/rl/utils.py
@ -0,0 +1,209 @@ |
|||||
|
import gymnasium as gym |
||||
|
import numpy as np |
||||
|
import random |
||||
|
|
||||
|
from minigrid.core.actions import Actions |
||||
|
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX |
||||
|
|
||||
|
from gymnasium.spaces import Dict, Box |
||||
|
from collections import deque |
||||
|
from ray.rllib.utils.numpy import one_hot |
||||
|
|
||||
|
from helpers import get_action_index_mapping |
||||
|
from shieldhandlers import ShieldHandler |
||||
|
|
||||
|
|
||||
|
class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
||||
|
def __init__(self, env, vector_index, framestack): |
||||
|
super().__init__(env) |
||||
|
self.framestack = framestack |
||||
|
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. |
||||
|
# +4: Direction. |
||||
|
self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 |
||||
|
self.init_x = None |
||||
|
self.init_y = None |
||||
|
self.x_positions = [] |
||||
|
self.y_positions = [] |
||||
|
self.x_y_delta_buffer = deque(maxlen=100) |
||||
|
self.vector_index = vector_index |
||||
|
self.frame_buffer = deque(maxlen=self.framestack) |
||||
|
for _ in range(self.framestack): |
||||
|
self.frame_buffer.append(np.zeros((self.single_frame_dim,))) |
||||
|
|
||||
|
self.observation_space = Dict( |
||||
|
{ |
||||
|
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), |
||||
|
"action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), |
||||
|
} |
||||
|
) |
||||
|
|
||||
|
def observation(self, obs): |
||||
|
# Debug output: max-x/y positions to watch exploration progress. |
||||
|
# print(F"Initial observation in Wrapper {obs}") |
||||
|
if self.step_count == 0: |
||||
|
for _ in range(self.framestack): |
||||
|
self.frame_buffer.append(np.zeros((self.single_frame_dim,))) |
||||
|
if self.vector_index == 0: |
||||
|
if self.x_positions: |
||||
|
max_diff = max( |
||||
|
np.sqrt( |
||||
|
(np.array(self.x_positions) - self.init_x) ** 2 |
||||
|
+ (np.array(self.y_positions) - self.init_y) ** 2 |
||||
|
) |
||||
|
) |
||||
|
self.x_y_delta_buffer.append(max_diff) |
||||
|
print( |
||||
|
"100-average dist travelled={}".format( |
||||
|
np.mean(self.x_y_delta_buffer) |
||||
|
) |
||||
|
) |
||||
|
self.x_positions = [] |
||||
|
self.y_positions = [] |
||||
|
self.init_x = self.agent_pos[0] |
||||
|
self.init_y = self.agent_pos[1] |
||||
|
|
||||
|
|
||||
|
self.x_positions.append(self.agent_pos[0]) |
||||
|
self.y_positions.append(self.agent_pos[1]) |
||||
|
|
||||
|
image = obs["data"] |
||||
|
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. |
||||
|
objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) |
||||
|
colors = one_hot(image[:, :, 1], depth=len(COLORS)) |
||||
|
states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) |
||||
|
|
||||
|
all_ = np.concatenate([objects, colors, states], -1) |
||||
|
all_flat = np.reshape(all_, (-1,)) |
||||
|
direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) |
||||
|
single_frame = np.concatenate([all_flat, direction]) |
||||
|
self.frame_buffer.append(single_frame) |
||||
|
|
||||
|
tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } |
||||
|
return tmp |
||||
|
|
||||
|
|
||||
|
class MiniGridShieldingWrapper(gym.core.Wrapper): |
||||
|
def __init__(self, |
||||
|
env, |
||||
|
shield_creator : ShieldHandler, |
||||
|
shield_query_creator, |
||||
|
create_shield_at_reset=True, |
||||
|
mask_actions=True): |
||||
|
super(MiniGridShieldingWrapper, self).__init__(env) |
||||
|
self.max_available_actions = env.action_space.n |
||||
|
self.observation_space = Dict( |
||||
|
{ |
||||
|
"data": env.observation_space.spaces["image"], |
||||
|
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), |
||||
|
} |
||||
|
) |
||||
|
self.shield_creator = shield_creator |
||||
|
self.create_shield_at_reset = create_shield_at_reset |
||||
|
self.shield = shield_creator.create_shield(env=self.env) |
||||
|
self.mask_actions = mask_actions |
||||
|
self.shield_query_creator = shield_query_creator |
||||
|
print(F"Shielding is {self.mask_actions}") |
||||
|
|
||||
|
def create_action_mask(self): |
||||
|
if not self.mask_actions: |
||||
|
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) |
||||
|
return ret |
||||
|
|
||||
|
cur_pos_str = self.shield_query_creator(self.env) |
||||
|
|
||||
|
# Create the mask |
||||
|
# If shield restricts action mask only valid with 1.0 |
||||
|
# else set all actions as valid |
||||
|
allowed_actions = [] |
||||
|
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) |
||||
|
|
||||
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
||||
|
allowed_actions = self.shield[cur_pos_str] |
||||
|
zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8) |
||||
|
has_allowed_actions = False |
||||
|
|
||||
|
for allowed_action in allowed_actions: |
||||
|
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set |
||||
|
if index is None: |
||||
|
assert(False) |
||||
|
|
||||
|
allowed = 1.0 |
||||
|
has_allowed_actions = True |
||||
|
mask[index] = allowed |
||||
|
else: |
||||
|
for index, x in enumerate(mask): |
||||
|
mask[index] = 1.0 |
||||
|
|
||||
|
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) |
||||
|
|
||||
|
if front_tile is not None and front_tile.type == "key": |
||||
|
mask[Actions.pickup] = 1.0 |
||||
|
|
||||
|
|
||||
|
if front_tile and front_tile.type == "door": |
||||
|
mask[Actions.toggle] = 1.0 |
||||
|
# print(F"Mask is {mask} State: {cur_pos_str}") |
||||
|
return mask |
||||
|
|
||||
|
def reset(self, *, seed=None, options=None): |
||||
|
obs, infos = self.env.reset(seed=seed, options=options) |
||||
|
|
||||
|
if self.create_shield_at_reset and self.mask_actions: |
||||
|
self.shield = self.shield_creator.create_shield(env=self.env) |
||||
|
|
||||
|
mask = self.create_action_mask() |
||||
|
return { |
||||
|
"data": obs["image"], |
||||
|
"action_mask": mask |
||||
|
}, infos |
||||
|
|
||||
|
def step(self, action): |
||||
|
orig_obs, rew, done, truncated, info = self.env.step(action) |
||||
|
|
||||
|
mask = self.create_action_mask() |
||||
|
obs = { |
||||
|
"data": orig_obs["image"], |
||||
|
"action_mask": mask, |
||||
|
} |
||||
|
|
||||
|
return obs, rew, done, truncated, info |
||||
|
|
||||
|
|
||||
|
def shielding_env_creater(config): |
||||
|
name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") |
||||
|
framestack = config.get("framestack", 4) |
||||
|
args = config.get("args", None) |
||||
|
args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" |
||||
|
args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism" |
||||
|
shielding = config.get("shielding", False) |
||||
|
shield_creator = MiniGridShieldHandler(grid_file=args.grid_path, |
||||
|
grid_to_prism_path=args.grid_to_prism_binary_path, |
||||
|
prism_path=args.prism_path, |
||||
|
formula=args.formula, |
||||
|
shield_value=args.shield_value, |
||||
|
prism_config=args.prism_config, |
||||
|
shield_comparision=args.shield_comparision) |
||||
|
|
||||
|
probability_intended = args.probability_intended |
||||
|
probability_displacement = args.probability_displacement |
||||
|
|
||||
|
env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement) |
||||
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) |
||||
|
|
||||
|
env = OneHotShieldingWrapper(env, |
||||
|
config.vector_index if hasattr(config, "vector_index") else 0, |
||||
|
framestack=framestack |
||||
|
) |
||||
|
|
||||
|
|
||||
|
return env |
||||
|
|
||||
|
|
||||
|
def register_minigrid_shielding_env(args): |
||||
|
env_name = "mini-grid-shielding" |
||||
|
register_env(env_name, shielding_env_creater) |
||||
|
|
||||
|
ModelCatalog.register_custom_model( |
||||
|
"shielding_model", |
||||
|
TorchActionMaskModel |
||||
|
) |
@ -0,0 +1,68 @@ |
|||||
|
import gymnasium as gym |
||||
|
import numpy as np |
||||
|
import random |
||||
|
|
||||
|
class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
||||
|
def __init__(self, |
||||
|
env, |
||||
|
shield_creator : ShieldHandler, |
||||
|
shield_query_creator, |
||||
|
create_shield_at_reset = True, |
||||
|
mask_actions=True, |
||||
|
): |
||||
|
super(MiniGridSbShieldingWrapper, self).__init__(env) |
||||
|
self.max_available_actions = env.action_space.n |
||||
|
self.observation_space = env.observation_space.spaces["image"] |
||||
|
|
||||
|
self.shield_creator = shield_creator |
||||
|
self.mask_actions = mask_actions |
||||
|
self.shield_query_creator = shield_query_creator |
||||
|
|
||||
|
def create_action_mask(self): |
||||
|
if not self.mask_actions: |
||||
|
return np.array([1.0] * self.max_available_actions, dtype=np.int8) |
||||
|
|
||||
|
cur_pos_str = self.shield_query_creator(self.env) |
||||
|
|
||||
|
allowed_actions = [] |
||||
|
|
||||
|
# Create the mask |
||||
|
# If shield restricts actions, mask only valid actions with 1.0 |
||||
|
# else set all actions valid |
||||
|
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) |
||||
|
|
||||
|
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
||||
|
allowed_actions = self.shield[cur_pos_str] |
||||
|
for allowed_action in allowed_actions: |
||||
|
index = get_action_index_mapping(allowed_action.labels) |
||||
|
if index is None: |
||||
|
assert(False) |
||||
|
|
||||
|
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
||||
|
else: |
||||
|
for index, x in enumerate(mask): |
||||
|
mask[index] = 1.0 |
||||
|
|
||||
|
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) |
||||
|
|
||||
|
|
||||
|
if front_tile and front_tile.type == "door": |
||||
|
mask[Actions.toggle] = 1.0 |
||||
|
|
||||
|
return mask |
||||
|
|
||||
|
|
||||
|
def reset(self, *, seed=None, options=None): |
||||
|
obs, infos = self.env.reset(seed=seed, options=options) |
||||
|
|
||||
|
shield = self.shield_creator.create_shield(env=self.env) |
||||
|
|
||||
|
self.shield = shield |
||||
|
return obs["image"], infos |
||||
|
|
||||
|
def step(self, action): |
||||
|
orig_obs, rew, done, truncated, info = self.env.step(action) |
||||
|
obs = orig_obs["image"] |
||||
|
|
||||
|
return obs, rew, done, truncated, info |
||||
|
|
Write
Preview
Loading…
Cancel
Save
Reference in new issue