
added utils classes

refactoring
Thomas Knoll 11 months ago
commit 5dcabef8e0
  1. examples/shields/rl/11_minigridrl.py (46 changes)
  2. examples/shields/rl/12_minigridrl_tune.py (33 changes)
  3. examples/shields/rl/13_minigridsb.py (23 changes)
  4. examples/shields/rl/14_train_eval.py (34 changes)
  5. examples/shields/rl/15_train_eval_tune.py (65 changes)
  6. examples/shields/rl/rllibutils.py (209 changes)
  7. examples/shields/rl/sb3utils.py (68 changes)
  8. examples/shields/rl/utils.py (124 changes)

examples/shields/rl/11_minigridrl.py (46 changes)

@@ -9,48 +9,12 @@ from ray.rllib.models import ModelCatalog
from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from callbacks import CustomCallback
from ray.tune.logger import TBXLogger
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
    args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"
    prob_forward = args.prob_forward
    prob_direct = args.prob_direct
    prob_next = args.prob_next
    shield_creator = MiniGridShieldHandler(args.grid_path,
                                           args.grid_to_prism_binary_path,
                                           args.prism_path,
                                           args.formula,
                                           args.shield_value,
                                           args.prism_config,
                                           shield_comparision=args.shield_comparision)
    env = gym.make(name, randomize_start=True, probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query,
                                   mask_actions=args.shielding != ShieldingConfig.Disabled,
                                   create_shield_at_reset=args.shield_creation_at_reset)
    # env = minigrid.wrappers.ImgObsWrapper(env)
    # env = ImgObsWrapper(env)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env

def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
@@ -71,7 +35,7 @@ def ppo(args):
        .resources(num_gpus=0)
        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(MyCallbacks)
        .callbacks(CustomCallback)
        .rl_module(_enable_rl_module_api = False)
        .debugging(logger_config={
            "type": TBXLogger,
@@ -109,7 +73,7 @@ def dqn(args):
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.callbacks(CustomCallback)
    config = config.rl_module(_enable_rl_module_api = False)
    config = config.debugging(logger_config={
        "type": TBXLogger,

examples/shields/rl/12_minigridrl_tune.py (33 changes)

@@ -11,36 +11,13 @@ from ray.rllib.models import ModelCatalog
from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from callbacks import CustomCallback
from torch.utils.tensorboard import SummaryWriter
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
    # env = minigrid.wrappers.ImgObsWrapper(env)
    # env = ImgObsWrapper(env)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env

def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
@@ -60,7 +37,7 @@ def ppo(args):
        .resources(num_gpus=0)
        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(MyCallbacks)
        .callbacks(CustomCallback)
        .rl_module(_enable_rl_module_api = False)
        .debugging(logger_config={
            "type": TBXLogger,
@@ -83,7 +60,7 @@ def dqn(args):
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.callbacks(CustomCallback)
    config = config.rl_module(_enable_rl_module_api = False)
    config = config.debugging(logger_config={
        "type": TBXLogger,

examples/shields/rl/13_minigridsb.py (23 changes)

@@ -2,7 +2,6 @@ from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
@@ -10,28 +9,13 @@ from minigrid.core.actions import Actions
import time
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from wrappers import MiniGridSbShieldingWrapper
class CustomCallback(BaseCallback):
    def __init__(self, verbose: int = 0, env=None):
        super(CustomCallback, self).__init__(verbose)
        self.env = env

    def _on_step(self) -> bool:
        print(self.env.printGrid())
        return super()._on_step()
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from sb3utils import MiniGridSbShieldingWrapper

def mask_fn(env: gym.Env):
    return env.create_action_mask()

def main():
    import argparse
    args = parse_arguments(argparse)
@@ -44,13 +28,12 @@ def main():
    env = gym.make(args.env, render_mode="rgb_array")
    env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
    env = ActionMasker(env, mask_fn)
    callback = CustomCallback(1, env)
    model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
    steps = args.steps
    model.learn(steps, callback=callback)
    model.learn(steps)
    # mean_reward, std_reward = evaluate_policy(model, model.get_env())
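
The commented-out line above hints at an evaluation step. A minimal sketch of finishing main() with sb3_contrib's evaluate_policy, which is already imported at the top of this file (n_eval_episodes is an arbitrary choice, not from the diff):

    # Hypothetical continuation after model.learn(steps):
    mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
    print(f"mean_reward={mean_reward:.2f} +/- {std_reward:.2f}")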

examples/shields/rl/14_train_eval.py (34 changes)

@@ -8,39 +8,13 @@ from ray.rllib.models import ModelCatalog
from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from callbacks import MyCallbacks
from callbacks import CustomCallback
from torch.utils.tensorboard import SummaryWriter
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)
    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env

def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)
@@ -60,7 +34,7 @@ def ppo(args):
        .environment(env="mini-grid-shielding",
                     env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(MyCallbacks)
        .callbacks(CustomCallback)
        .evaluation(evaluation_config={
            "evaluation_interval": 1,
            "evaluation_duration": 10,

examples/shields/rl/15_train_eval_tune.py (65 changes)

@@ -14,43 +14,11 @@ from ray.rllib.algorithms.callbacks import make_multi_callbacks
from ray.air import session
from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig, test_name
from torch.utils.tensorboard import SummaryWriter
from callbacks import MyCallbacks
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)
    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
                                           grid_to_prism_path=args.grid_to_prism_binary_path,
                                           prism_path=args.prism_path,
                                           formula=args.formula,
                                           shield_value=args.shield_value,
                                           prism_config=args.prism_config,
                                           shield_comparision=args.shield_comparision)
    prob_forward = args.prob_forward
    prob_direct = args.prob_direct
    prob_next = args.prob_next
    env = gym.make(name, randomize_start=True, probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding != ShieldingConfig.Disabled)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env
from callbacks import CustomCallback
def register_minigrid_shielding_env(args):
@@ -79,7 +47,7 @@ def ppo(args):
            "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
        },)
        .framework("torch")
        .callbacks(MyCallbacks)
        .callbacks(CustomCallback)
        .evaluation(evaluation_config={
            "evaluation_interval": 1,
            "evaluation_duration": 10,
@@ -133,31 +101,6 @@ def ppo(args):
    ]
    pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
    # algo = Algorithm.from_checkpoint(best_result.checkpoint)
    # eval_log_dir = F"{logdir}-eval"
    # writer = SummaryWriter(log_dir=eval_log_dir)
    # csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
    # for i in range(args.evaluations):
    #     eval_result = algo.evaluate()
    #     print(pretty_print(eval_result))
    #     print(eval_result)
    #     # logger.on_result(eval_result)
    #     csv_logger.on_result(eval_result)
    #     evaluation = eval_result['evaluation']
    #     epsiode_reward_mean = evaluation['episode_reward_mean']
    #     episode_len_mean = evaluation['episode_len_mean']
    #     print(epsiode_reward_mean)
    #     writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
    #     writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
def main():
    ray.init(num_cpus=3)
    import argparse

examples/shields/rl/rllibutils.py (209 changes)

@@ -0,0 +1,209 @@
import gymnasium as gym
import numpy as np
import random

from minigrid.core.actions import Actions
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
from gymnasium.spaces import Dict, Box
from collections import deque

from ray.rllib.utils.numpy import one_hot
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
from utils import (ShieldHandler, MiniGridShieldHandler, create_shield_query,
                   ShieldingConfig, get_action_index_mapping)
class OneHotShieldingWrapper(gym.core.ObservationWrapper):
    def __init__(self, env, vector_index, framestack):
        super().__init__(env)
        self.framestack = framestack
        # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
        # +4: Direction.
        self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        for _ in range(self.framestack):
            self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
            }
        )
    def observation(self, obs):
        # Debug output: max-x/y positions to watch exploration progress.
        # print(F"Initial observation in Wrapper {obs}")
        if self.step_count == 0:
            for _ in range(self.framestack):
                self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
            if self.vector_index == 0:
                if self.x_positions:
                    max_diff = max(
                        np.sqrt(
                            (np.array(self.x_positions) - self.init_x) ** 2
                            + (np.array(self.y_positions) - self.init_y) ** 2
                        )
                    )
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = self.agent_pos[0]
                self.init_y = self.agent_pos[1]
        self.x_positions.append(self.agent_pos[0])
        self.y_positions.append(self.agent_pos[1])
        image = obs["data"]
        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
        colors = one_hot(image[:, :, 1], depth=len(COLORS))
        states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)
        tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"]}
        return tmp
class MiniGridShieldingWrapper(gym.core.Wrapper):
    def __init__(self,
                 env,
                 shield_creator: ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True):
        super(MiniGridShieldingWrapper, self).__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "action_mask": Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
            }
        )
        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        self.shield = shield_creator.create_shield(env=self.env)
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        print(F"Shielding is {self.mask_actions}")
    def create_action_mask(self):
        if not self.mask_actions:
            ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
            return ret
        cur_pos_str = self.shield_query_creator(self.env)
        # Create the mask.
        # If the shield restricts actions, mark only the allowed actions with 1.0;
        # else set all actions as valid.
        allowed_actions = []
        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
            allowed_actions = self.shield[cur_pos_str]
            zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
            has_allowed_actions = False
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)  # allowed_action is a set
                if index is None:
                    assert(False)
                allowed = 1.0
                has_allowed_actions = True
                mask[index] = allowed
        else:
            for index, x in enumerate(mask):
                mask[index] = 1.0
        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile is not None and front_tile.type == "key":
            mask[Actions.pickup] = 1.0
        if front_tile and front_tile.type == "door":
            mask[Actions.toggle] = 1.0
        # print(F"Mask is {mask} State: {cur_pos_str}")
        return mask
    def reset(self, *, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)
        if self.create_shield_at_reset and self.mask_actions:
            self.shield = self.shield_creator.create_shield(env=self.env)
        mask = self.create_action_mask()
        return {
            "data": obs["image"],
            "action_mask": mask
        }, infos

    def step(self, action):
        orig_obs, rew, done, truncated, info = self.env.step(action)
        mask = self.create_action_mask()
        obs = {
            "data": orig_obs["image"],
            "action_mask": mask,
        }
        return obs, rew, done, truncated, info
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)
    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
                                           grid_to_prism_path=args.grid_to_prism_binary_path,
                                           prism_path=args.prism_path,
                                           formula=args.formula,
                                           shield_value=args.shield_value,
                                           prism_config=args.prism_config,
                                           shield_comparision=args.shield_comparision)
    probability_intended = args.probability_intended
    probability_displacement = args.probability_displacement
    env = gym.make(name, randomize_start=True, probability_intended=probability_intended, probability_displacement=probability_displacement)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding != ShieldingConfig.Disabled)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env

def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)
    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )
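
Taken together, a minimal driver for these helpers might look as follows. This is a sketch mirroring how 11_minigridrl.py wires things up, assuming the same modules are importable; the PPOConfig calls are standard RLlib API and not part of this diff:

    import argparse
    from ray.rllib.algorithms.ppo import PPOConfig
    from rllibutils import register_minigrid_shielding_env
    from utils import parse_arguments

    args = parse_arguments(argparse)
    register_minigrid_shielding_env(args)  # registers "mini-grid-shielding" and "shielding_model"
    config = (
        PPOConfig()
        .environment(env="mini-grid-shielding",
                     env_config={"name": args.env, "args": args})
        .framework("torch")
        .training(model={"custom_model": "shielding_model"})
    )
    algo = config.build()
    algo.train()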

examples/shields/rl/sb3utils.py (68 changes)

@@ -0,0 +1,68 @@
import gymnasium as gym
import numpy as np
import random

from minigrid.core.actions import Actions
from utils import ShieldHandler, get_action_index_mapping
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    def __init__(self,
                 env,
                 shield_creator: ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True,
                 ):
        super(MiniGridSbShieldingWrapper, self).__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]
        self.shield_creator = shield_creator
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
    def create_action_mask(self):
        if not self.mask_actions:
            return np.array([1.0] * self.max_available_actions, dtype=np.int8)
        cur_pos_str = self.shield_query_creator(self.env)
        allowed_actions = []
        # Create the mask.
        # If the shield restricts actions, mask only valid actions with 1.0;
        # else set all actions valid.
        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
            allowed_actions = self.shield[cur_pos_str]
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)
                if index is None:
                    assert(False)
                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
        else:
            for index, x in enumerate(mask):
                mask[index] = 1.0
        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile and front_tile.type == "door":
            mask[Actions.toggle] = 1.0
        return mask
    def reset(self, *, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)
        shield = self.shield_creator.create_shield(env=self.env)
        self.shield = shield
        return obs["image"], infos

    def step(self, action):
        orig_obs, rew, done, truncated, info = self.env.step(action)
        obs = orig_obs["image"]
        return obs, rew, done, truncated, info
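
Note the difference from the RLlib wrapper above: here an allowed action is unmasked stochastically, with probability allowed_action.prob. In isolation, the sampling rule reduces to this (0.75 is an illustrative value, not from the diff):

    import random

    # An action the shield allows with probability 0.75 is unmasked
    # about 75% of the time, and masked out otherwise:
    prob = 0.75
    mask_value = random.choices([0.0, 1.0], weights=(1 - prob, prob))[0]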

examples/shields/rl/shieldhandlers.py → examples/shields/rl/utils.py (124 changes)

@@ -78,7 +78,6 @@ class MiniGridShieldHandler(ShieldHandler):
        assert result.has_shield
        shield = result.shield
        stormpy.shields.export_shield(model, shield, "Grid.shield")
        action_dictionary = {}
        shield_scheduler = shield.construct()
        state_valuations = model.state_valuations
@@ -193,4 +192,125 @@ def create_shield_query(env):
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
return query
class ShieldingConfig(Enum):
    Training = 'training'
    Evaluation = 'evaluation'
    Disabled = 'none'
    Full = 'full'

    def __str__(self) -> str:
        return self.value
def extract_keys(env):
    keys = []
    for j in range(env.grid.height):
        for i in range(env.grid.width):
            obj = env.grid.get(i, j)
            if obj and obj.type == "key":
                keys.append((obj, i, j))
    if env.carrying and env.carrying.type == "key":
        keys.append((env.carrying, -1, -1))
    # TODO Maybe need to add ordering of keys so it matches the order in the shield
    return keys
def extract_doors(env):
    doors = []
    for j in range(env.grid.height):
        for i in range(env.grid.width):
            obj = env.grid.get(i, j)
            if obj and obj.type == "door":
                doors.append(obj)
    return doors
def extract_adversaries(env):
    adv = []
    if not hasattr(env, "adversaries"):
        return []
    for color, adversary in env.adversaries.items():
        adv.append(adversary)
    return adv
def create_log_dir(args):
    return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"

def test_name(args):
    return F"{args.expname}"
def get_action_index_mapping(actions):
    for action_str in actions:
        if "Agent" not in action_str:
            continue
        if "move" in action_str:
            return Actions.forward
        elif "left" in action_str:
            return Actions.left
        elif "right" in action_str:
            return Actions.right
        elif "pickup" in action_str:
            return Actions.pickup
        elif "done" in action_str:
            return Actions.done
        elif "drop" in action_str:
            return Actions.drop
        elif "toggle" in action_str:
            return Actions.toggle
        elif "unlock" in action_str:
            return Actions.toggle
    raise ValueError("No action mapping found")
def parse_arguments(argparse):
    parser = argparse.ArgumentParser()
    # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
    parser.add_argument("--env",
                        help="gym environment to load",
                        default="MiniGrid-LavaSlipperyS12-v2",
                        choices=[
                            "MiniGrid-Adv-8x8-v0",
                            "MiniGrid-AdvSimple-8x8-v0",
                            "MiniGrid-LavaCrossingS9N1-v0",
                            "MiniGrid-LavaCrossingS9N3-v0",
                            "MiniGrid-LavaSlipperyS12-v0",
                            "MiniGrid-LavaSlipperyS12-v1",
                            "MiniGrid-LavaSlipperyS12-v2",
                            "MiniGrid-LavaSlipperyS12-v3",
                        ])
    # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
    parser.add_argument("--grid_to_prism_binary_path", default="./main")
    parser.add_argument("--grid_path", default="grid")
    parser.add_argument("--prism_path", default="grid")
    parser.add_argument("--algorithm", default="PPO", type=str.upper, choices=["PPO", "DQN"])
    parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--evaluations", type=int, default=30)
    parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
    # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--num_gpus", type=float, default=0)
    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
    parser.add_argument("--steps", default=20_000, type=int)
    parser.add_argument("--expname", default="exp")
    parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
    parser.add_argument("--prism_config", default=None)
    parser.add_argument("--shield_value", default=0.9, type=float)
    parser.add_argument("--probability_displacement", default=1/4, type=float)
    parser.add_argument("--probability_intended", default=3/4, type=float)
    parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
    # parser.add_argument("--random_starts", default=1, type=int)
    args = parser.parse_args()
    return args
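
For reference, a minimal sketch of how the example scripts above consume this module; the pass-the-argparse-module convention matches their main() functions, and the variable names here are illustrative:

    import argparse
    from utils import parse_arguments, create_log_dir, ShieldingConfig

    args = parse_arguments(argparse)
    log_dir = create_log_dir(args)   # encodes shielding, shield_value, env, prism_config
    shielding_enabled = args.shielding is not ShieldingConfig.Disabled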