Browse Source

added utils classes

refactoring
Thomas Knoll 1 year ago
parent
commit
5dcabef8e0
  1. 46
      examples/shields/rl/11_minigridrl.py
  2. 33
      examples/shields/rl/12_minigridrl_tune.py
  3. 23
      examples/shields/rl/13_minigridsb.py
  4. 34
      examples/shields/rl/14_train_eval.py
  5. 65
      examples/shields/rl/15_train_eval_tune.py
  6. 209
      examples/shields/rl/rllibutils.py
  7. 68
      examples/shields/rl/sb3utils.py
  8. 124
      examples/shields/rl/utils.py

46
examples/shields/rl/11_minigridrl.py

@ -9,48 +9,12 @@ from ray.rllib.models import ModelCatalog
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from helpers import parse_arguments, create_log_dir, ShieldingConfig from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query from callbacks import CustomCallback
from callbacks import MyCallbacks
from ray.tune.logger import TBXLogger from ray.tune.logger import TBXLogger
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"
prob_forward = args.prob_forward
prob_direct = args.prob_direct
prob_next = args.prob_next
shield_creator = MiniGridShieldHandler(args.grid_path,
args.grid_to_prism_binary_path,
args.prism_path,
args.formula,
args.shield_value,
args.prism_config,
shield_comparision=args.shield_comparision)
env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
shield_query_creator=create_shield_query,
mask_actions=args.shielding != ShieldingConfig.Disabled,
create_shield_at_reset=args.shield_creation_at_reset)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
return env
def register_minigrid_shielding_env(args): def register_minigrid_shielding_env(args):
env_name = "mini-grid-shielding" env_name = "mini-grid-shielding"
@ -71,7 +35,7 @@ def ppo(args):
.resources(num_gpus=0) .resources(num_gpus=0)
.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
.framework("torch") .framework("torch")
.callbacks(MyCallbacks) .callbacks(CustomCallback)
.rl_module(_enable_rl_module_api = False) .rl_module(_enable_rl_module_api = False)
.debugging(logger_config={ .debugging(logger_config={
"type": TBXLogger, "type": TBXLogger,
@ -109,7 +73,7 @@ def dqn(args):
config = config.rollouts(num_rollout_workers=args.workers) config = config.rollouts(num_rollout_workers=args.workers)
config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args }) config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
config = config.framework("torch") config = config.framework("torch")
config = config.callbacks(MyCallbacks) config = config.callbacks(CustomCallback)
config = config.rl_module(_enable_rl_module_api = False) config = config.rl_module(_enable_rl_module_api = False)
config = config.debugging(logger_config={ config = config.debugging(logger_config={
"type": TBXLogger, "type": TBXLogger,

33
examples/shields/rl/12_minigridrl_tune.py

@ -11,36 +11,13 @@ from ray.rllib.models import ModelCatalog
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query from callbacks import CustomCallback
from callbacks import MyCallbacks
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
return env
def register_minigrid_shielding_env(args): def register_minigrid_shielding_env(args):
env_name = "mini-grid-shielding" env_name = "mini-grid-shielding"
@ -60,7 +37,7 @@ def ppo(args):
.resources(num_gpus=0) .resources(num_gpus=0)
.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
.framework("torch") .framework("torch")
.callbacks(MyCallbacks) .callbacks(CustomCallback)
.rl_module(_enable_rl_module_api = False) .rl_module(_enable_rl_module_api = False)
.debugging(logger_config={ .debugging(logger_config={
"type": TBXLogger, "type": TBXLogger,
@ -83,7 +60,7 @@ def dqn(args):
config = config.rollouts(num_rollout_workers=args.workers) config = config.rollouts(num_rollout_workers=args.workers)
config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args }) config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
config = config.framework("torch") config = config.framework("torch")
config = config.callbacks(MyCallbacks) config = config.callbacks(CustomCallback)
config = config.rl_module(_enable_rl_module_api = False) config = config.rl_module(_enable_rl_module_api = False)
config = config.debugging(logger_config={ config = config.debugging(logger_config={
"type": TBXLogger, "type": TBXLogger,

23
examples/shields/rl/13_minigridsb.py

@ -2,7 +2,6 @@ from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym import gymnasium as gym
@ -10,28 +9,13 @@ from minigrid.core.actions import Actions
import time import time
from helpers import parse_arguments, create_log_dir, ShieldingConfig from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query from sb3utils import MiniGridSbShieldingWrapper
from wrappers import MiniGridSbShieldingWrapper
class CustomCallback(BaseCallback):
def __init__(self, verbose: int = 0, env=None):
super(CustomCallback, self).__init__(verbose)
self.env = env
def _on_step(self) -> bool:
print(self.env.printGrid())
return super()._on_step()
def mask_fn(env: gym.Env): def mask_fn(env: gym.Env):
return env.create_action_mask() return env.create_action_mask()
def main(): def main():
import argparse import argparse
args = parse_arguments(argparse) args = parse_arguments(argparse)
@ -44,13 +28,12 @@ def main():
env = gym.make(args.env, render_mode="rgb_array") env = gym.make(args.env, render_mode="rgb_array")
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn) env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
steps = args.steps steps = args.steps
model.learn(steps, callback=callback) model.learn(steps)
#W mean_reward, std_reward = evaluate_policy(model, model.get_env()) #W mean_reward, std_reward = evaluate_policy(model, model.get_env())

34
examples/shields/rl/14_train_eval.py

@ -8,39 +8,13 @@ from ray.rllib.models import ModelCatalog
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from helpers import parse_arguments, create_log_dir, ShieldingConfig from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks from callbacks import CustomCallback
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
shielding = config.get("shielding", False)
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
return env
def register_minigrid_shielding_env(args): def register_minigrid_shielding_env(args):
env_name = "mini-grid-shielding" env_name = "mini-grid-shielding"
register_env(env_name, shielding_env_creater) register_env(env_name, shielding_env_creater)
@ -60,7 +34,7 @@ def ppo(args):
.environment( env="mini-grid-shielding", .environment( env="mini-grid-shielding",
env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
.framework("torch") .framework("torch")
.callbacks(MyCallbacks) .callbacks(CustomCallback)
.evaluation(evaluation_config={ .evaluation(evaluation_config={
"evaluation_interval": 1, "evaluation_interval": 1,
"evaluation_duration": 10, "evaluation_duration": 10,

65
examples/shields/rl/15_train_eval_tune.py

@ -14,43 +14,11 @@ from ray.rllib.algorithms.callbacks import make_multi_callbacks
from ray.air import session from ray.air import session
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig, test_name
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from callbacks import MyCallbacks from callbacks import CustomCallback
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
shielding = config.get("shielding", False)
shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
grid_to_prism_path=args.grid_to_prism_binary_path,
prism_path=args.prism_path,
formula=args.formula,
shield_value=args.shield_value,
prism_config=args.prism_config,
shield_comparision=args.shield_comparision)
prob_forward = args.prob_forward
prob_direct = args.prob_direct
prob_next = args.prob_next
env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
return env
def register_minigrid_shielding_env(args): def register_minigrid_shielding_env(args):
@ -79,7 +47,7 @@ def ppo(args):
"shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
},) },)
.framework("torch") .framework("torch")
.callbacks(MyCallbacks) .callbacks(CustomCallback)
.evaluation(evaluation_config={ .evaluation(evaluation_config={
"evaluation_interval": 1, "evaluation_interval": 1,
"evaluation_duration": 10, "evaluation_duration": 10,
@ -133,31 +101,6 @@ def ppo(args):
] ]
pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print}) pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
# algo = Algorithm.from_checkpoint(best_result.checkpoint)
# eval_log_dir = F"{logdir}-eval"
# writer = SummaryWriter(log_dir=eval_log_dir)
# csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
# for i in range(args.evaluations):
# eval_result = algo.evaluate()
# print(pretty_print(eval_result))
# print(eval_result)
# # logger.on_result(eval_result)
# csv_logger.on_result(eval_result)
# evaluation = eval_result['evaluation']
# epsiode_reward_mean = evaluation['episode_reward_mean']
# episode_len_mean = evaluation['episode_len_mean']
# print(epsiode_reward_mean)
# writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
# writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
def main(): def main():
ray.init(num_cpus=3) ray.init(num_cpus=3)
import argparse import argparse

209
examples/shields/rl/rllibutils.py

@ -0,0 +1,209 @@
import gymnasium as gym
import numpy as np
import random
from minigrid.core.actions import Actions
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping
from shieldhandlers import ShieldHandler
class OneHotShieldingWrapper(gym.core.ObservationWrapper):
    """One-hot encodes the MiniGrid image observation and stacks the last
    `framestack` frames into a single flat float vector.

    The shield's "action_mask" entry is passed through unchanged; only the
    "data" entry is re-encoded.
    """

    def __init__(self, env, vector_index, framestack):
        super().__init__(env)
        self.framestack = framestack
        # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
        # +4: Direction.
        self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
        self.init_x = None  # agent start position of the current episode
        self.init_y = None
        self.x_positions = []  # positions visited during the current episode
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)  # per-episode max travel distances
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        # Pre-fill with zero frames so the first observation already has
        # `framestack` frames to concatenate.
        for _ in range(self.framestack):
            self.frame_buffer.append(np.zeros((self.single_frame_dim,)))

        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
            }
        )

    def observation(self, obs):
        # Debug output: max-x/y positions to watch exploration progress.
        # print(F"Initial observation in Wrapper {obs}")
        if self.step_count == 0:
            # Episode start: clear the frame stack; for the first sub-env
            # (vector_index 0) also log a rolling 100-episode average of the
            # maximum distance travelled from the start position.
            for _ in range(self.framestack):
                self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
            if self.vector_index == 0:
                if self.x_positions:
                    max_diff = max(
                        np.sqrt(
                            (np.array(self.x_positions) - self.init_x) ** 2
                            + (np.array(self.y_positions) - self.init_y) ** 2
                        )
                    )
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = self.agent_pos[0]
                self.init_y = self.agent_pos[1]

        self.x_positions.append(self.agent_pos[0])
        self.y_positions.append(self.agent_pos[1])

        image = obs["data"]
        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
        colors = one_hot(image[:, :, 1], depth=len(COLORS))
        states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)

        tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
        return tmp
class MiniGridShieldingWrapper(gym.core.Wrapper):
    """RLlib wrapper producing {"data": image, "action_mask": mask}
    observations, where the mask is derived from a model-checking shield.

    Parameters
    ----------
    shield_creator : ShieldHandler
        Builds the shield (a mapping from state-query strings to allowed
        actions) for the current environment layout.
    shield_query_creator : callable
        Maps the wrapped env to the query string used to look up the shield.
    create_shield_at_reset : bool
        Rebuild the shield on every reset (for randomized layouts).
    mask_actions : bool
        When False, every action is always allowed.
    """

    def __init__(self,
                 env,
                 shield_creator : ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True):
        super(MiniGridShieldingWrapper, self).__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
            }
        )
        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        self.shield = shield_creator.create_shield(env=self.env)
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        print(F"Shielding is {self.mask_actions}")

    def create_action_mask(self):
        """Return an int8 vector with 1 for every currently allowed action."""
        # Without shielding every action is allowed.
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        cur_pos_str = self.shield_query_creator(self.env)
        mask = np.zeros(self.max_available_actions, dtype=np.int8)

        # If the shield restricts actions in this state, enable only those;
        # otherwise every action is valid.
        allowed_actions = self.shield.get(cur_pos_str)
        if allowed_actions:
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)  # labels is a set
                if index is None:
                    # Fix: replaced assert(False) (stripped under -O) with an
                    # explicit error; also dropped the unused `zeroes` and
                    # `has_allowed_actions` locals of the original.
                    raise RuntimeError(F"No action mapping for shield action {allowed_action}")
                mask[index] = 1
        else:
            mask[:] = 1  # fix: vectorized fill instead of an enumerate loop

        # Interacting with the tile directly in front of the agent is always
        # allowed, regardless of the shield.
        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile is not None and front_tile.type == "key":
            mask[Actions.pickup] = 1
        if front_tile and front_tile.type == "door":
            mask[Actions.toggle] = 1

        # print(F"Mask is {mask} State: {cur_pos_str}")
        return mask

    def reset(self, *, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)
        if self.create_shield_at_reset and self.mask_actions:
            # Rebuild the shield for the (possibly randomized) new layout.
            self.shield = self.shield_creator.create_shield(env=self.env)
        mask = self.create_action_mask()
        return {
            "data": obs["image"],
            "action_mask": mask
        }, infos

    def step(self, action):
        orig_obs, rew, done, truncated, info = self.env.step(action)
        mask = self.create_action_mask()
        obs = {
            "data": orig_obs["image"],
            "action_mask": mask,
        }
        return obs, rew, done, truncated, info
def shielding_env_creater(config):
    """RLlib env-creator: build a shielded, one-hot encoded MiniGrid env.

    NOTE(review): MiniGridShieldHandler, create_shield_query and
    ShieldingConfig are used below but this module's visible imports only
    cover `helpers`/`shieldhandlers` — presumably they should come from the
    renamed `utils` module introduced in this commit; verify the imports.
    """
    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    # Give each rollout worker its own grid/prism file to avoid clashes.
    # NOTE(review): this mutates `args` in place — calling the creator twice
    # with the same args object appends the worker index again; confirm each
    # worker invokes this exactly once.
    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)
    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
                                           grid_to_prism_path=args.grid_to_prism_binary_path,
                                           prism_path=args.prism_path,
                                           formula=args.formula,
                                           shield_value=args.shield_value,
                                           prism_config=args.prism_config,
                                           shield_comparision=args.shield_comparision)

    probability_intended = args.probability_intended
    probability_displacement = args.probability_displacement

    env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement)
    # NOTE(review): `shielding` defaults to the bool False but is compared
    # against ShieldingConfig.Disabled — False != Disabled is True, so masking
    # is ON unless callers pass ShieldingConfig.Disabled; confirm intended.
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack
                                 )
    return env
def register_minigrid_shielding_env(args):
    """Register the shielded MiniGrid env and the masking model with RLlib.

    NOTE(review): register_env, ModelCatalog and TorchActionMaskModel are not
    among this module's visible imports — confirm they are imported from
    ray.tune.registry, ray.rllib.models and torch_action_mask_model.
    """
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )

68
examples/shields/rl/sb3utils.py

@ -0,0 +1,68 @@
import gymnasium as gym
import numpy as np
import random
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Stable-Baselines3 wrapper exposing the raw image observation and a
    shield-derived action mask via create_action_mask().

    NOTE(review): ShieldHandler, get_action_index_mapping and Actions are
    referenced but not among this module's visible imports — confirm they are
    imported (utils / minigrid.core.actions).
    """

    def __init__(self,
                 env,
                 shield_creator : ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset = True,
                 mask_actions=True,
                 ):
        super(MiniGridSbShieldingWrapper, self).__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]

        self.shield_creator = shield_creator
        # Fix: this flag was previously accepted but never stored, so it had
        # no effect and the shield was unconditionally rebuilt on every reset.
        self.create_shield_at_reset = create_shield_at_reset
        # Fix: build an initial shield so create_action_mask() also works
        # before the first reset (the original only set self.shield in
        # reset()); mirrors the RLlib wrapper in rllibutils.py.
        self.shield = shield_creator.create_shield(env=env)
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator

    def create_action_mask(self):
        """Return an int8 vector with 1 for each currently allowed action."""
        # Without shielding every action is allowed.
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        cur_pos_str = self.shield_query_creator(self.env)
        mask = np.zeros(self.max_available_actions, dtype=np.int8)

        # If the shield restricts actions in this state, enable only those;
        # otherwise every action is valid.
        allowed_actions = self.shield.get(cur_pos_str)
        if allowed_actions:
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)
                if index is None:
                    # Fix: replaced assert(False) (stripped under -O) with an
                    # explicit error.
                    raise RuntimeError(F"No action mapping for shield action {allowed_action}")
                # Sample the permission with the shield's action probability.
                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
        else:
            mask[:] = 1

        # Toggling a door in front of the agent is always allowed.
        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile and front_tile.type == "door":
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        obs, infos = self.env.reset(seed=seed, options=options)
        # Rebuild the shield for the new episode layout when requested
        # (previously this happened unconditionally).
        if self.create_shield_at_reset and self.mask_actions:
            self.shield = self.shield_creator.create_shield(env=self.env)
        return obs["image"], infos

    def step(self, action):
        orig_obs, rew, done, truncated, info = self.env.step(action)
        return orig_obs["image"], rew, done, truncated, info

124
examples/shields/rl/shieldhandlers.py → examples/shields/rl/utils.py

@ -78,7 +78,6 @@ class MiniGridShieldHandler(ShieldHandler):
assert result.has_shield assert result.has_shield
shield = result.shield shield = result.shield
stormpy.shields.export_shield(model, shield, "Grid.shield")
action_dictionary = {} action_dictionary = {}
shield_scheduler = shield.construct() shield_scheduler = shield.construct()
state_valuations = model.state_valuations state_valuations = model.state_valuations
@ -193,4 +192,125 @@ def create_shield_query(env):
query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]" query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
return query return query
class ShieldingConfig(Enum):
    """CLI-selectable mode controlling when action shielding is applied."""
    Training = 'training'      # shield only while training
    Evaluation = 'evaluation'  # shield only while evaluating
    Disabled = 'none'          # no shielding at all
    Full = 'full'              # shield during training and evaluation

    def __str__(self) -> str:
        # Render the raw CLI value so argparse help/choices read cleanly.
        return self.value
def extract_keys(env):
    """Collect every key object in *env* as (obj, x, y) triples.

    The grid is scanned row by row; a key currently carried by the agent is
    reported last with the sentinel position (-1, -1).
    """
    found = [
        (cell, col, row)
        for row in range(env.grid.height)
        for col in range(env.grid.width)
        if (cell := env.grid.get(col, row)) and cell.type == "key"
    ]
    if env.carrying and env.carrying.type == "key":
        found.append((env.carrying, -1, -1))
    # TODO Maybe need to add ordering of keys so it matches the order in the shield
    return found
def extract_doors(env):
    """Return every door object present in the grid, scanned row by row."""
    return [
        cell
        for row in range(env.grid.height)
        for col in range(env.grid.width)
        if (cell := env.grid.get(col, row)) and cell.type == "door"
    ]
def extract_adversaries(env):
    """Return the adversary objects of *env*.

    Environments without an `adversaries` attribute yield an empty list.

    Fix: the original iterated `.items()` while ignoring the color key and
    appended in a manual loop; take the dict values directly instead.
    """
    if not hasattr(env, "adversaries"):
        return []
    return list(env.adversaries.values())
def create_log_dir(args):
    """Build the log-directory name encoding the experiment's settings.

    Note: `args.log_dir` is used as a raw prefix, so it should end with a
    path separator.
    """
    return (
        f"{args.log_dir}sh:{args.shielding}"
        f"-value:{args.shield_value}"
        f"-comp:{args.shield_comparision}"
        f"-env:{args.env}"
        f"-conf:{args.prism_config}"
    )
def test_name(args):
return F"{args.expname}"
def get_action_index_mapping(actions):
    """Map a collection of shield action labels to the MiniGrid action.

    Only labels containing "Agent" are considered; within a label the first
    matching keyword (in the order below) decides the action. Raises
    ValueError when no label yields a mapping.
    """
    # Keyword -> action, checked in the same precedence order as before.
    keyword_actions = (
        ("move", Actions.forward),
        ("left", Actions.left),
        ("right", Actions.right),
        ("pickup", Actions.pickup),
        ("done", Actions.done),
        ("drop", Actions.drop),
        ("toggle", Actions.toggle),
        ("unlock", Actions.toggle),
    )
    for action_str in actions:
        if "Agent" not in action_str:
            continue
        for keyword, action in keyword_actions:
            if keyword in action_str:
                return action
    raise ValueError("No action mapping found")
def parse_arguments(argparse):
    """Parse the command-line options shared by the shielding experiments.

    The caller passes in the argparse *module*; the parsed Namespace is
    returned.
    """
    parser = argparse.ArgumentParser()
    # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
    parser.add_argument("--env",
                        help="gym environment to load",
                        default="MiniGrid-LavaSlipperyS12-v2",
                        choices=[
                            "MiniGrid-Adv-8x8-v0",
                            "MiniGrid-AdvSimple-8x8-v0",
                            "MiniGrid-LavaCrossingS9N1-v0",
                            "MiniGrid-LavaCrossingS9N3-v0",
                            "MiniGrid-LavaSlipperyS12-v0",
                            "MiniGrid-LavaSlipperyS12-v1",
                            "MiniGrid-LavaSlipperyS12-v2",
                            "MiniGrid-LavaSlipperyS12-v3",
                        ])
    # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
    # Path to the external grid-to-PRISM converter binary.
    parser.add_argument("--grid_to_prism_binary_path", default="./main")
    parser.add_argument("--grid_path", default="grid")
    parser.add_argument("--prism_path", default="grid")
    parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
    parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--evaluations", type=int, default=30 )
    # Safety property model-checked to compute the shield.
    parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
    # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
    parser.add_argument("--workers", type=int, default=1)
    parser.add_argument("--num_gpus", type=float, default=0)
    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
    parser.add_argument("--steps", default=20_000, type=int)
    parser.add_argument("--expname", default="exp")
    parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
    parser.add_argument("--prism_config", default=None)
    parser.add_argument("--shield_value", default=0.9, type=float)
    parser.add_argument("--probability_displacement", default=1/4, type=float)
    parser.add_argument("--probability_intended", default=3/4, type=float)
    # NOTE(review): "comparision" is misspelled but is part of the CLI and
    # attribute interface — renaming it would break existing scripts.
    parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
    # parser.add_argument("--random_starts", default=1, type=int)
    args = parser.parse_args()

    return args
|||||||
100:0
Loading…
Cancel
Save