Thomas Knoll
1 year ago
6 changed files with 13 additions and 428 deletions
-
23examples/shields/rl/callbacks.py
-
149examples/shields/rl/helpers.py
-
5examples/shields/rl/rllibutils.py
-
4examples/shields/rl/sb3utils.py
-
8examples/shields/rl/utils.py
-
252examples/shields/rl/wrappers.py
@ -1,149 +0,0 @@ |
|||||
import minigrid |
|
||||
from minigrid.core.actions import Actions |
|
||||
|
|
||||
from datetime import datetime |
|
||||
from enum import Enum |
|
||||
|
|
||||
import os |
|
||||
|
|
||||
import stormpy |
|
||||
import stormpy.core |
|
||||
import stormpy.simulator |
|
||||
|
|
||||
import stormpy.shields |
|
||||
import stormpy.logic |
|
||||
|
|
||||
import stormpy.examples |
|
||||
import stormpy.examples.files |
|
||||
|
|
||||
class ShieldingConfig(Enum): |
|
||||
Training = 'training' |
|
||||
Evaluation = 'evaluation' |
|
||||
Disabled = 'none' |
|
||||
Full = 'full' |
|
||||
|
|
||||
def __str__(self) -> str: |
|
||||
return self.value |
|
||||
|
|
||||
|
|
||||
def extract_keys(env): |
|
||||
keys = [] |
|
||||
for j in range(env.grid.height): |
|
||||
for i in range(env.grid.width): |
|
||||
obj = env.grid.get(i,j) |
|
||||
|
|
||||
if obj and obj.type == "key": |
|
||||
keys.append((obj, i, j)) |
|
||||
|
|
||||
if env.carrying and env.carrying.type == "key": |
|
||||
keys.append((env.carrying, -1, -1)) |
|
||||
# TODO Maybe need to add ordering of keys so it matches the order in the shield |
|
||||
return keys |
|
||||
|
|
||||
def extract_doors(env): |
|
||||
doors = [] |
|
||||
for j in range(env.grid.height): |
|
||||
for i in range(env.grid.width): |
|
||||
obj = env.grid.get(i,j) |
|
||||
|
|
||||
if obj and obj.type == "door": |
|
||||
doors.append(obj) |
|
||||
|
|
||||
return doors |
|
||||
|
|
||||
def extract_adversaries(env): |
|
||||
adv = [] |
|
||||
|
|
||||
if not hasattr(env, "adversaries"): |
|
||||
return [] |
|
||||
|
|
||||
for color, adversary in env.adversaries.items(): |
|
||||
adv.append(adversary) |
|
||||
|
|
||||
|
|
||||
return adv |
|
||||
|
|
||||
def create_log_dir(args): |
|
||||
return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}" |
|
||||
|
|
||||
def test_name(args): |
|
||||
return F"{args.expname}" |
|
||||
|
|
||||
def get_action_index_mapping(actions): |
|
||||
for action_str in actions: |
|
||||
if not "Agent" in action_str: |
|
||||
continue |
|
||||
|
|
||||
if "move" in action_str: |
|
||||
return Actions.forward |
|
||||
elif "left" in action_str: |
|
||||
return Actions.left |
|
||||
elif "right" in action_str: |
|
||||
return Actions.right |
|
||||
elif "pickup" in action_str: |
|
||||
return Actions.pickup |
|
||||
elif "done" in action_str: |
|
||||
return Actions.done |
|
||||
elif "drop" in action_str: |
|
||||
return Actions.drop |
|
||||
elif "toggle" in action_str: |
|
||||
return Actions.toggle |
|
||||
elif "unlock" in action_str: |
|
||||
return Actions.toggle |
|
||||
|
|
||||
raise ValueError("No action mapping found") |
|
||||
|
|
||||
|
|
||||
def parse_arguments(argparse): |
|
||||
parser = argparse.ArgumentParser() |
|
||||
# parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") |
|
||||
parser.add_argument("--env", |
|
||||
help="gym environment to load", |
|
||||
default="MiniGrid-LavaSlipperyS12-v2", |
|
||||
choices=[ |
|
||||
"MiniGrid-Adv-8x8-v0", |
|
||||
"MiniGrid-AdvSimple-8x8-v0", |
|
||||
"MiniGrid-SingleDoor-7x6-v0", |
|
||||
"MiniGrid-LavaCrossingS9N1-v0", |
|
||||
"MiniGrid-LavaCrossingS9N3-v0", |
|
||||
"MiniGrid-LavaSlipperyS12-v0", |
|
||||
"MiniGrid-LavaSlipperyS12-v1", |
|
||||
"MiniGrid-LavaSlipperyS12-v2", |
|
||||
"MiniGrid-LavaSlipperyS12-v3", |
|
||||
"MiniGrid-DoorKey-8x8-v0", |
|
||||
# "MiniGrid-DoubleDoor-16x16-v0", |
|
||||
# "MiniGrid-DoubleDoor-12x12-v0", |
|
||||
# "MiniGrid-DoubleDoor-10x8-v0", |
|
||||
# "MiniGrid-LockedRoom-v0", |
|
||||
# "MiniGrid-FourRooms-v0", |
|
||||
# "MiniGrid-LavaGapS7-v0", |
|
||||
# "MiniGrid-SimpleCrossingS9N3-v0", |
|
||||
# "MiniGrid-DoorKey-16x16-v0", |
|
||||
# "MiniGrid-Empty-Random-6x6-v0", |
|
||||
]) |
|
||||
|
|
||||
# parser.add_argument("--seed", type=int, help="seed for environment", default=None) |
|
||||
parser.add_argument("--grid_to_prism_binary_path", default="./main") |
|
||||
parser.add_argument("--grid_path", default="grid") |
|
||||
parser.add_argument("--prism_path", default="grid") |
|
||||
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) |
|
||||
parser.add_argument("--log_dir", default="../log_results/") |
|
||||
parser.add_argument("--evaluations", type=int, default=30 ) |
|
||||
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" |
|
||||
# parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]") |
|
||||
parser.add_argument("--workers", type=int, default=1) |
|
||||
parser.add_argument("--num_gpus", type=float, default=0) |
|
||||
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) |
|
||||
parser.add_argument("--steps", default=20_000, type=int) |
|
||||
parser.add_argument("--expname", default="exp") |
|
||||
parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) |
|
||||
parser.add_argument("--prism_config", default=None) |
|
||||
parser.add_argument("--shield_value", default=0.9, type=float) |
|
||||
parser.add_argument("--prob_direct", default=1/4, type=float) |
|
||||
parser.add_argument("--prob_forward", default=3/4, type=float) |
|
||||
parser.add_argument("--prob_next", default=1/8, type=float) |
|
||||
parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) |
|
||||
# parser.add_argument("--random_starts", default=1, type=int) |
|
||||
args = parser.parse_args() |
|
||||
|
|
||||
return args |
|
@ -1,252 +0,0 @@ |
|||||
import gymnasium as gym |
|
||||
import numpy as np |
|
||||
import random |
|
||||
|
|
||||
from minigrid.core.actions import Actions |
|
||||
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX |
|
||||
|
|
||||
from gymnasium.spaces import Dict, Box |
|
||||
from collections import deque |
|
||||
from ray.rllib.utils.numpy import one_hot |
|
||||
|
|
||||
from helpers import get_action_index_mapping |
|
||||
from shieldhandlers import ShieldHandler |
|
||||
|
|
||||
|
|
||||
class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
||||
def __init__(self, env, vector_index, framestack): |
|
||||
super().__init__(env) |
|
||||
self.framestack = framestack |
|
||||
# 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. |
|
||||
# +4: Direction. |
|
||||
self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 |
|
||||
self.init_x = None |
|
||||
self.init_y = None |
|
||||
self.x_positions = [] |
|
||||
self.y_positions = [] |
|
||||
self.x_y_delta_buffer = deque(maxlen=100) |
|
||||
self.vector_index = vector_index |
|
||||
self.frame_buffer = deque(maxlen=self.framestack) |
|
||||
for _ in range(self.framestack): |
|
||||
self.frame_buffer.append(np.zeros((self.single_frame_dim,))) |
|
||||
|
|
||||
self.observation_space = Dict( |
|
||||
{ |
|
||||
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), |
|
||||
"action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), |
|
||||
} |
|
||||
) |
|
||||
|
|
||||
def observation(self, obs): |
|
||||
# Debug output: max-x/y positions to watch exploration progress. |
|
||||
# print(F"Initial observation in Wrapper {obs}") |
|
||||
if self.step_count == 0: |
|
||||
for _ in range(self.framestack): |
|
||||
self.frame_buffer.append(np.zeros((self.single_frame_dim,))) |
|
||||
if self.vector_index == 0: |
|
||||
if self.x_positions: |
|
||||
max_diff = max( |
|
||||
np.sqrt( |
|
||||
(np.array(self.x_positions) - self.init_x) ** 2 |
|
||||
+ (np.array(self.y_positions) - self.init_y) ** 2 |
|
||||
) |
|
||||
) |
|
||||
self.x_y_delta_buffer.append(max_diff) |
|
||||
print( |
|
||||
"100-average dist travelled={}".format( |
|
||||
np.mean(self.x_y_delta_buffer) |
|
||||
) |
|
||||
) |
|
||||
self.x_positions = [] |
|
||||
self.y_positions = [] |
|
||||
self.init_x = self.agent_pos[0] |
|
||||
self.init_y = self.agent_pos[1] |
|
||||
|
|
||||
|
|
||||
self.x_positions.append(self.agent_pos[0]) |
|
||||
self.y_positions.append(self.agent_pos[1]) |
|
||||
|
|
||||
image = obs["data"] |
|
||||
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. |
|
||||
objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) |
|
||||
colors = one_hot(image[:, :, 1], depth=len(COLORS)) |
|
||||
states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) |
|
||||
|
|
||||
all_ = np.concatenate([objects, colors, states], -1) |
|
||||
all_flat = np.reshape(all_, (-1,)) |
|
||||
direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) |
|
||||
single_frame = np.concatenate([all_flat, direction]) |
|
||||
self.frame_buffer.append(single_frame) |
|
||||
|
|
||||
tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } |
|
||||
return tmp |
|
||||
|
|
||||
|
|
||||
class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
||||
def __init__(self, |
|
||||
env, |
|
||||
shield_creator : ShieldHandler, |
|
||||
shield_query_creator, |
|
||||
create_shield_at_reset=True, |
|
||||
mask_actions=True): |
|
||||
super(MiniGridShieldingWrapper, self).__init__(env) |
|
||||
self.max_available_actions = env.action_space.n |
|
||||
self.observation_space = Dict( |
|
||||
{ |
|
||||
"data": env.observation_space.spaces["image"], |
|
||||
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), |
|
||||
} |
|
||||
) |
|
||||
self.shield_creator = shield_creator |
|
||||
self.create_shield_at_reset = create_shield_at_reset |
|
||||
self.shield = shield_creator.create_shield(env=self.env) |
|
||||
self.mask_actions = mask_actions |
|
||||
self.shield_query_creator = shield_query_creator |
|
||||
print(F"Shielding is {self.mask_actions}") |
|
||||
|
|
||||
def create_action_mask(self): |
|
||||
# print(F"{self.mask_actions} No shielding") |
|
||||
if not self.mask_actions: |
|
||||
# print("No shielding") |
|
||||
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|
||||
# print(ret) |
|
||||
return ret |
|
||||
|
|
||||
cur_pos_str = self.shield_query_creator(self.env) |
|
||||
# print(F"Pos string {cur_pos_str}") |
|
||||
# print(F"Shield {list(self.shield.keys())[0]}") |
|
||||
# print(F"Is pos str in shield: {cur_pos_str in self.shield}, Position Str {cur_pos_str}") |
|
||||
# Create the mask |
|
||||
# If shield restricts action mask only valid with 1.0 |
|
||||
# else set all actions as valid |
|
||||
allowed_actions = [] |
|
||||
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) |
|
||||
|
|
||||
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|
||||
allowed_actions = self.shield[cur_pos_str] |
|
||||
zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8) |
|
||||
has_allowed_actions = False |
|
||||
|
|
||||
for allowed_action in allowed_actions: |
|
||||
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set |
|
||||
if index is None: |
|
||||
print(F"No mapping for action {list(allowed_action.labels)}") |
|
||||
print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}") |
|
||||
assert(False) |
|
||||
|
|
||||
allowed = 1.0 # random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
|
||||
if allowed_action.prob == 0 and allowed: |
|
||||
assert False |
|
||||
if allowed: |
|
||||
has_allowed_actions = True |
|
||||
mask[index] = allowed |
|
||||
|
|
||||
# if not has_allowed_actions: |
|
||||
# print(F"No action allowed for pos string {cur_pos_str}") |
|
||||
# assert(False) |
|
||||
|
|
||||
else: |
|
||||
for index, x in enumerate(mask): |
|
||||
mask[index] = 1.0 |
|
||||
|
|
||||
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) |
|
||||
|
|
||||
if front_tile is not None and front_tile.type == "key": |
|
||||
mask[Actions.pickup] = 1.0 |
|
||||
|
|
||||
|
|
||||
if front_tile and front_tile.type == "door": |
|
||||
mask[Actions.toggle] = 1.0 |
|
||||
# print(F"Mask is {mask} State: {cur_pos_str}") |
|
||||
return mask |
|
||||
|
|
||||
def reset(self, *, seed=None, options=None): |
|
||||
obs, infos = self.env.reset(seed=seed, options=options) |
|
||||
|
|
||||
if self.create_shield_at_reset and self.mask_actions: |
|
||||
self.shield = self.shield_creator.create_shield(env=self.env) |
|
||||
|
|
||||
mask = self.create_action_mask() |
|
||||
return { |
|
||||
"data": obs["image"], |
|
||||
"action_mask": mask |
|
||||
}, infos |
|
||||
|
|
||||
def step(self, action): |
|
||||
orig_obs, rew, done, truncated, info = self.env.step(action) |
|
||||
|
|
||||
mask = self.create_action_mask() |
|
||||
obs = { |
|
||||
"data": orig_obs["image"], |
|
||||
"action_mask": mask, |
|
||||
} |
|
||||
|
|
||||
return obs, rew, done, truncated, info |
|
||||
|
|
||||
|
|
||||
|
|
||||
class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
||||
def __init__(self, |
|
||||
env, |
|
||||
shield_creator : ShieldHandler, |
|
||||
shield_query_creator, |
|
||||
create_shield_at_reset = True, |
|
||||
mask_actions=True, |
|
||||
): |
|
||||
super(MiniGridSbShieldingWrapper, self).__init__(env) |
|
||||
self.max_available_actions = env.action_space.n |
|
||||
self.observation_space = env.observation_space.spaces["image"] |
|
||||
|
|
||||
self.shield_creator = shield_creator |
|
||||
self.mask_actions = mask_actions |
|
||||
self.shield_query_creator = shield_query_creator |
|
||||
print(F"Shielding is {self.mask_actions}") |
|
||||
|
|
||||
def create_action_mask(self): |
|
||||
if not self.mask_actions: |
|
||||
return np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|
||||
|
|
||||
cur_pos_str = self.shield_query_creator(self.env) |
|
||||
|
|
||||
allowed_actions = [] |
|
||||
|
|
||||
# Create the mask |
|
||||
# If shield restricts actions, mask only valid actions with 1.0 |
|
||||
# else set all actions valid |
|
||||
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) |
|
||||
|
|
||||
if cur_pos_str in self.shield and self.shield[cur_pos_str]: |
|
||||
allowed_actions = self.shield[cur_pos_str] |
|
||||
for allowed_action in allowed_actions: |
|
||||
index = get_action_index_mapping(allowed_action.labels) |
|
||||
if index is None: |
|
||||
assert(False) |
|
||||
|
|
||||
mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] |
|
||||
else: |
|
||||
for index, x in enumerate(mask): |
|
||||
mask[index] = 1.0 |
|
||||
|
|
||||
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) |
|
||||
|
|
||||
|
|
||||
if front_tile and front_tile.type == "door": |
|
||||
mask[Actions.toggle] = 1.0 |
|
||||
|
|
||||
return mask |
|
||||
|
|
||||
|
|
||||
def reset(self, *, seed=None, options=None): |
|
||||
obs, infos = self.env.reset(seed=seed, options=options) |
|
||||
|
|
||||
shield = self.shield_creator.create_shield(env=self.env) |
|
||||
|
|
||||
self.shield = shield |
|
||||
return obs["image"], infos |
|
||||
|
|
||||
def step(self, action): |
|
||||
orig_obs, rew, done, truncated, info = self.env.step(action) |
|
||||
obs = orig_obs["image"] |
|
||||
|
|
||||
return obs, rew, done, truncated, info |
|
||||
|
|
Write
Preview
Loading…
Cancel
Save
Reference in new issue