diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 7e5660d..2d7aa02 100755 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -9,48 +9,12 @@ from ray.rllib.models import ModelCatalog from torch_action_mask_model import TorchActionMaskModel -from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper -from helpers import parse_arguments, create_log_dir, ShieldingConfig -from shieldhandlers import MiniGridShieldHandler, create_shield_query -from callbacks import MyCallbacks +from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater +from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig +from callbacks import CustomCallback from ray.tune.logger import TBXLogger -def shielding_env_creater(config): - name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") - framestack = config.get("framestack", 4) - args = config.get("args", None) - args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt" - args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism" - - prob_forward = args.prob_forward - prob_direct = args.prob_direct - prob_next = args.prob_next - - shield_creator = MiniGridShieldHandler(args.grid_path, - args.grid_to_prism_binary_path, - args.prism_path, - args.formula, - args.shield_value, - args.prism_config, - shield_comparision=args.shield_comparision) - - env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, - shield_query_creator=create_shield_query, - mask_actions=args.shielding != ShieldingConfig.Disabled, - create_shield_at_reset=args.shield_creation_at_reset) - # env = minigrid.wrappers.ImgObsWrapper(env) - # env = ImgObsWrapper(env) - env = OneHotShieldingWrapper(env, - config.vector_index if hasattr(config, "vector_index") else 0, - framestack=framestack - ) - - - return env - - def register_minigrid_shielding_env(args): env_name = "mini-grid-shielding" @@ -71,7 +35,7 @@ def ppo(args): .resources(num_gpus=0) .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .framework("torch") - .callbacks(MyCallbacks) + .callbacks(CustomCallback) .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ "type": TBXLogger, @@ -109,7 +73,7 @@ def dqn(args): config = config.rollouts(num_rollout_workers=args.workers) config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args }) config = config.framework("torch") - config = config.callbacks(MyCallbacks) + config = config.callbacks(CustomCallback) config = config.rl_module(_enable_rl_module_api = False) config = config.debugging(logger_config={ "type": TBXLogger, diff --git a/examples/shields/rl/12_minigridrl_tune.py b/examples/shields/rl/12_minigridrl_tune.py index e0ff945..e64d609 100644 --- a/examples/shields/rl/12_minigridrl_tune.py +++ b/examples/shields/rl/12_minigridrl_tune.py @@ -11,36 +11,13 @@ from ray.rllib.models import ModelCatalog from torch_action_mask_model import TorchActionMaskModel -from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper -from helpers import parse_arguments, create_log_dir, ShieldingConfig -from shieldhandlers import MiniGridShieldHandler, create_shield_query -from callbacks import MyCallbacks +from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper +from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig +from callbacks import CustomCallback from torch.utils.tensorboard import SummaryWriter from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger -def shielding_env_creater(config): - name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") - framestack = config.get("framestack", 4) - args = config.get("args", None) - args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" - args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" - - shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) - - env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) - # env = minigrid.wrappers.ImgObsWrapper(env) - # env = ImgObsWrapper(env) - env = OneHotShieldingWrapper(env, - config.vector_index if hasattr(config, "vector_index") else 0, - framestack=framestack - ) - - - return env - - def register_minigrid_shielding_env(args): env_name = "mini-grid-shielding" @@ -60,7 +37,7 @@ def ppo(args): .resources(num_gpus=0) .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .framework("torch") - .callbacks(MyCallbacks) + .callbacks(CustomCallback) .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ "type": TBXLogger, @@ -83,7 +60,7 @@ def dqn(args): config = config.rollouts(num_rollout_workers=args.workers) config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args }) config = config.framework("torch") - config = config.callbacks(MyCallbacks) + config = config.callbacks(CustomCallback) config = config.rl_module(_enable_rl_module_api = False) config = config.debugging(logger_config={ "type": TBXLogger, diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 43575dc..f04c65e 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -2,7 +2,6 @@ from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.wrappers import ActionMasker -from stable_baselines3.common.callbacks import BaseCallback import gymnasium as gym @@ -10,28 +9,13 @@ from minigrid.core.actions import Actions import time -from helpers import parse_arguments, create_log_dir, ShieldingConfig -from shieldhandlers import MiniGridShieldHandler, create_shield_query -from wrappers import MiniGridSbShieldingWrapper - -class CustomCallback(BaseCallback): - def __init__(self, verbose: int = 0, env=None): - super(CustomCallback, self).__init__(verbose) - self.env = env - - - def _on_step(self) -> bool: - print(self.env.printGrid()) - return super()._on_step() - - - +from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig +from sb3utils import MiniGridSbShieldingWrapper def mask_fn(env: gym.Env): return env.create_action_mask() - def main(): import argparse args = parse_arguments(argparse) @@ -44,13 +28,12 @@ def main(): env = gym.make(args.env, render_mode="rgb_array") env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) - callback = CustomCallback(1, env) model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) steps = args.steps - model.learn(steps, callback=callback) + model.learn(steps) #W mean_reward, std_reward = evaluate_policy(model, model.get_env()) diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py index dae1c77..56913a6 100644 --- a/examples/shields/rl/14_train_eval.py +++ b/examples/shields/rl/14_train_eval.py @@ -8,39 +8,13 @@ from ray.rllib.models import ModelCatalog from torch_action_mask_model import TorchActionMaskModel -from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper -from helpers import parse_arguments, create_log_dir, ShieldingConfig -from shieldhandlers import MiniGridShieldHandler, create_shield_query +from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater +from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig -from callbacks import MyCallbacks +from callbacks import CustomCallback from torch.utils.tensorboard import SummaryWriter - - - -def shielding_env_creater(config): - name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") - framestack = config.get("framestack", 4) - args = config.get("args", None) - args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" - args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" - - shielding = config.get("shielding", False) - shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) - - env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) - - env = OneHotShieldingWrapper(env, - config.vector_index if hasattr(config, "vector_index") else 0, - framestack=framestack - ) - - - return env - - def register_minigrid_shielding_env(args): env_name = "mini-grid-shielding" register_env(env_name, shielding_env_creater) @@ -60,7 +34,7 @@ def ppo(args): .environment( env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .framework("torch") - .callbacks(MyCallbacks) + .callbacks(CustomCallback) .evaluation(evaluation_config={ "evaluation_interval": 1, "evaluation_duration": 10, diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index 9220af5..ad9d5c1 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -14,43 +14,11 @@ from ray.rllib.algorithms.callbacks import make_multi_callbacks from ray.air import session from torch_action_mask_model import TorchActionMaskModel -from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper -from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name -from shieldhandlers import MiniGridShieldHandler, create_shield_query +from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater +from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig, test_name from torch.utils.tensorboard import SummaryWriter -from callbacks import MyCallbacks - - -def shielding_env_creater(config): - name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") - framestack = config.get("framestack", 4) - args = config.get("args", None) - args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" - args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism" - shielding = config.get("shielding", False) - shield_creator = MiniGridShieldHandler(grid_file=args.grid_path, - grid_to_prism_path=args.grid_to_prism_binary_path, - prism_path=args.prism_path, - formula=args.formula, - shield_value=args.shield_value, - prism_config=args.prism_config, - shield_comparision=args.shield_comparision) - - prob_forward = args.prob_forward - prob_direct = args.prob_direct - prob_next = args.prob_next - - env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) - - env = OneHotShieldingWrapper(env, - config.vector_index if hasattr(config, "vector_index") else 0, - framestack=framestack - ) - - - return env +from callbacks import CustomCallback def register_minigrid_shielding_env(args): @@ -79,7 +47,7 @@ def ppo(args): "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, },) .framework("torch") - .callbacks(MyCallbacks) + .callbacks(CustomCallback) .evaluation(evaluation_config={ "evaluation_interval": 1, "evaluation_duration": 10, @@ -133,31 +101,6 @@ def ppo(args): ] pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print}) - # algo = Algorithm.from_checkpoint(best_result.checkpoint) - - - # eval_log_dir = F"{logdir}-eval" - - # writer = SummaryWriter(log_dir=eval_log_dir) - # csv_logger = CSVLogger(config=config, logdir=eval_log_dir) - - - # for i in range(args.evaluations): - # eval_result = algo.evaluate() - # print(pretty_print(eval_result)) - # print(eval_result) - # # logger.on_result(eval_result) - - # csv_logger.on_result(eval_result) - - # evaluation = eval_result['evaluation'] - # epsiode_reward_mean = evaluation['episode_reward_mean'] - # episode_len_mean = evaluation['episode_len_mean'] - # print(epsiode_reward_mean) - # writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i) - # writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i) - - def main(): ray.init(num_cpus=3) import argparse diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py new file mode 100644 index 0000000..03b8253 --- /dev/null +++ b/examples/shields/rl/rllibutils.py @@ -0,0 +1,209 @@ +import gymnasium as gym +import numpy as np +import random + +from minigrid.core.actions import Actions +from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX + +from gymnasium.spaces import Dict, Box +from collections import deque +from ray.rllib.utils.numpy import one_hot + +from helpers import get_action_index_mapping +from shieldhandlers import ShieldHandler + + +class OneHotShieldingWrapper(gym.core.ObservationWrapper): + def __init__(self, env, vector_index, framestack): + super().__init__(env) + self.framestack = framestack + # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. + # +4: Direction. + self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 + self.init_x = None + self.init_y = None + self.x_positions = [] + self.y_positions = [] + self.x_y_delta_buffer = deque(maxlen=100) + self.vector_index = vector_index + self.frame_buffer = deque(maxlen=self.framestack) + for _ in range(self.framestack): + self.frame_buffer.append(np.zeros((self.single_frame_dim,))) + + self.observation_space = Dict( + { + "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), + "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), + } + ) + + def observation(self, obs): + # Debug output: max-x/y positions to watch exploration progress. + # print(F"Initial observation in Wrapper {obs}") + if self.step_count == 0: + for _ in range(self.framestack): + self.frame_buffer.append(np.zeros((self.single_frame_dim,))) + if self.vector_index == 0: + if self.x_positions: + max_diff = max( + np.sqrt( + (np.array(self.x_positions) - self.init_x) ** 2 + + (np.array(self.y_positions) - self.init_y) ** 2 + ) + ) + self.x_y_delta_buffer.append(max_diff) + print( + "100-average dist travelled={}".format( + np.mean(self.x_y_delta_buffer) + ) + ) + self.x_positions = [] + self.y_positions = [] + self.init_x = self.agent_pos[0] + self.init_y = self.agent_pos[1] + + + self.x_positions.append(self.agent_pos[0]) + self.y_positions.append(self.agent_pos[1]) + + image = obs["data"] + # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. + objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) + colors = one_hot(image[:, :, 1], depth=len(COLORS)) + states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) + + all_ = np.concatenate([objects, colors, states], -1) + all_flat = np.reshape(all_, (-1,)) + direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) + single_frame = np.concatenate([all_flat, direction]) + self.frame_buffer.append(single_frame) + + tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } + return tmp + + +class MiniGridShieldingWrapper(gym.core.Wrapper): + def __init__(self, + env, + shield_creator : ShieldHandler, + shield_query_creator, + create_shield_at_reset=True, + mask_actions=True): + super(MiniGridShieldingWrapper, self).__init__(env) + self.max_available_actions = env.action_space.n + self.observation_space = Dict( + { + "data": env.observation_space.spaces["image"], + "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), + } + ) + self.shield_creator = shield_creator + self.create_shield_at_reset = create_shield_at_reset + self.shield = shield_creator.create_shield(env=self.env) + self.mask_actions = mask_actions + self.shield_query_creator = shield_query_creator + print(F"Shielding is {self.mask_actions}") + + def create_action_mask(self): + if not self.mask_actions: + ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) + return ret + + cur_pos_str = self.shield_query_creator(self.env) + + # Create the mask + # If shield restricts action mask only valid with 1.0 + # else set all actions as valid + allowed_actions = [] + mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) + + if cur_pos_str in self.shield and self.shield[cur_pos_str]: + allowed_actions = self.shield[cur_pos_str] + zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8) + has_allowed_actions = False + + for allowed_action in allowed_actions: + index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set + if index is None: + assert(False) + + allowed = 1.0 + has_allowed_actions = True + mask[index] = allowed + else: + for index, x in enumerate(mask): + mask[index] = 1.0 + + front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) + + if front_tile is not None and front_tile.type == "key": + mask[Actions.pickup] = 1.0 + + + if front_tile and front_tile.type == "door": + mask[Actions.toggle] = 1.0 + # print(F"Mask is {mask} State: {cur_pos_str}") + return mask + + def reset(self, *, seed=None, options=None): + obs, infos = self.env.reset(seed=seed, options=options) + + if self.create_shield_at_reset and self.mask_actions: + self.shield = self.shield_creator.create_shield(env=self.env) + + mask = self.create_action_mask() + return { + "data": obs["image"], + "action_mask": mask + }, infos + + def step(self, action): + orig_obs, rew, done, truncated, info = self.env.step(action) + + mask = self.create_action_mask() + obs = { + "data": orig_obs["image"], + "action_mask": mask, + } + + return obs, rew, done, truncated, info + + +def shielding_env_creater(config): + name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") + framestack = config.get("framestack", 4) + args = config.get("args", None) + args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" + args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism" + shielding = config.get("shielding", False) + shield_creator = MiniGridShieldHandler(grid_file=args.grid_path, + grid_to_prism_path=args.grid_to_prism_binary_path, + prism_path=args.prism_path, + formula=args.formula, + shield_value=args.shield_value, + prism_config=args.prism_config, + shield_comparision=args.shield_comparision) + + probability_intended = args.probability_intended + probability_displacement = args.probability_displacement + + env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) + + env = OneHotShieldingWrapper(env, + config.vector_index if hasattr(config, "vector_index") else 0, + framestack=framestack + ) + + + return env + + +def register_minigrid_shielding_env(args): + env_name = "mini-grid-shielding" + register_env(env_name, shielding_env_creater) + + ModelCatalog.register_custom_model( + "shielding_model", + TorchActionMaskModel + ) diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py new file mode 100644 index 0000000..f0798d2 --- /dev/null +++ b/examples/shields/rl/sb3utils.py @@ -0,0 +1,68 @@ +import gymnasium as gym +import numpy as np +import random + +class MiniGridSbShieldingWrapper(gym.core.Wrapper): + def __init__(self, + env, + shield_creator : ShieldHandler, + shield_query_creator, + create_shield_at_reset = True, + mask_actions=True, + ): + super(MiniGridSbShieldingWrapper, self).__init__(env) + self.max_available_actions = env.action_space.n + self.observation_space = env.observation_space.spaces["image"] + + self.shield_creator = shield_creator + self.mask_actions = mask_actions + self.shield_query_creator = shield_query_creator + + def create_action_mask(self): + if not self.mask_actions: + return np.array([1.0] * self.max_available_actions, dtype=np.int8) + + cur_pos_str = self.shield_query_creator(self.env) + + allowed_actions = [] + + # Create the mask + # If shield restricts actions, mask only valid actions with 1.0 + # else set all actions valid + mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) + + if cur_pos_str in self.shield and self.shield[cur_pos_str]: + allowed_actions = self.shield[cur_pos_str] + for allowed_action in allowed_actions: + index = get_action_index_mapping(allowed_action.labels) + if index is None: + assert(False) + + mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] + else: + for index, x in enumerate(mask): + mask[index] = 1.0 + + front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) + + + if front_tile and front_tile.type == "door": + mask[Actions.toggle] = 1.0 + + return mask + + + def reset(self, *, seed=None, options=None): + obs, infos = self.env.reset(seed=seed, options=options) + + shield = self.shield_creator.create_shield(env=self.env) + + self.shield = shield + return obs["image"], infos + + def step(self, action): + orig_obs, rew, done, truncated, info = self.env.step(action) + obs = orig_obs["image"] + + return obs, rew, done, truncated, info + diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/utils.py similarity index 59% rename from examples/shields/rl/shieldhandlers.py rename to examples/shields/rl/utils.py index 95cab72..cab65be 100644 --- a/examples/shields/rl/shieldhandlers.py +++ b/examples/shields/rl/utils.py @@ -78,7 +78,6 @@ class MiniGridShieldHandler(ShieldHandler): assert result.has_shield shield = result.shield - stormpy.shields.export_shield(model, shield, "Grid.shield") action_dictionary = {} shield_scheduler = shield.construct() state_valuations = model.state_valuations @@ -193,4 +192,125 @@ def create_shield_query(env): query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]" return query - \ No newline at end of file + + +class ShieldingConfig(Enum): + Training = 'training' + Evaluation = 'evaluation' + Disabled = 'none' + Full = 'full' + + def __str__(self) -> str: + return self.value + + +def extract_keys(env): + keys = [] + for j in range(env.grid.height): + for i in range(env.grid.width): + obj = env.grid.get(i,j) + + if obj and obj.type == "key": + keys.append((obj, i, j)) + + if env.carrying and env.carrying.type == "key": + keys.append((env.carrying, -1, -1)) + # TODO Maybe need to add ordering of keys so it matches the order in the shield + return keys + +def extract_doors(env): + doors = [] + for j in range(env.grid.height): + for i in range(env.grid.width): + obj = env.grid.get(i,j) + + if obj and obj.type == "door": + doors.append(obj) + + return doors + +def extract_adversaries(env): + adv = [] + + if not hasattr(env, "adversaries"): + return [] + + for color, adversary in env.adversaries.items(): + adv.append(adversary) + + + return adv + +def create_log_dir(args): + return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}" + +def test_name(args): + return F"{args.expname}" + +def get_action_index_mapping(actions): + for action_str in actions: + if not "Agent" in action_str: + continue + + if "move" in action_str: + return Actions.forward + elif "left" in action_str: + return Actions.left + elif "right" in action_str: + return Actions.right + elif "pickup" in action_str: + return Actions.pickup + elif "done" in action_str: + return Actions.done + elif "drop" in action_str: + return Actions.drop + elif "toggle" in action_str: + return Actions.toggle + elif "unlock" in action_str: + return Actions.toggle + + raise ValueError("No action mapping found") + + +def parse_arguments(argparse): + parser = argparse.ArgumentParser() + # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") + parser.add_argument("--env", + help="gym environment to load", + default="MiniGrid-LavaSlipperyS12-v2", + choices=[ + "MiniGrid-Adv-8x8-v0", + "MiniGrid-AdvSimple-8x8-v0", + "MiniGrid-LavaCrossingS9N1-v0", + "MiniGrid-LavaCrossingS9N3-v0", + "MiniGrid-LavaSlipperyS12-v0", + "MiniGrid-LavaSlipperyS12-v1", + "MiniGrid-LavaSlipperyS12-v2", + "MiniGrid-LavaSlipperyS12-v3", + + ]) + + # parser.add_argument("--seed", type=int, help="seed for environment", default=None) + parser.add_argument("--grid_to_prism_binary_path", default="./main") + parser.add_argument("--grid_path", default="grid") + parser.add_argument("--prism_path", default="grid") + parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) + parser.add_argument("--log_dir", default="../log_results/") + parser.add_argument("--evaluations", type=int, default=30 ) + parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" + # parser.add_argument("--formula", default="<> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]") + parser.add_argument("--workers", type=int, default=1) + parser.add_argument("--num_gpus", type=float, default=0) + parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) + parser.add_argument("--steps", default=20_000, type=int) + parser.add_argument("--expname", default="exp") + parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) + parser.add_argument("--prism_config", default=None) + parser.add_argument("--shield_value", default=0.9, type=float) + parser.add_argument("--probability_displacement", default=1/4, type=float) + parser.add_argument("--probability_intended", default=3/4, type=float) + parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) + # parser.add_argument("--random_starts", default=1, type=int) + args = parser.parse_args() + + return args