diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index ace2639..497f48b 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -37,6 +37,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN from ray.rllib.models.preprocessors import get_preprocessor from MaskModels import TorchActionMaskModel from Wrapper import OneHotWrapper, MiniGridEnvWrapper +from helpers import extract_keys import matplotlib.pyplot as plt @@ -61,12 +62,12 @@ class MyCallbacks(DefaultCallbacks): def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: episode.user_data["count"] = episode.user_data["count"] + 1 env = base_env.get_sub_environments()[0] - #print(env.env.env.printGrid()) + #print(env.printGrid()) def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] - # print(env.env.env.printGrid()) + # print(env.printGrid()) # print(episode.user_data["count"]) @@ -96,17 +97,21 @@ def parse_arguments(argparse): "MiniGrid-RedBlueDoors-6x6-v0",]) # parser.add_argument("--seed", type=int, help="seed for environment", default=None) + parser.add_argument("--grid_to_prism_path", default="./main") parser.add_argument("--grid_path", default="Grid.txt") parser.add_argument("--prism_path", default="Grid.PRISM") parser.add_argument("--no_masking", default=False) parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) parser.add_argument("--log_dir", default="../log_results/") + parser.add_argument("--iterations", type=int, default=30 ) args = parser.parse_args() return args + + def env_creater_custom(config): framestack = config.get("framestack", 4) shield = config.get("shield", {}) @@ -114,7 +119,8 @@ def env_creater_custom(config): framestack = config.get("framestack", 4) env = gym.make(name) - env = MiniGridEnvWrapper(env, shield=shield) + keys = extract_keys(env) + env = MiniGridEnvWrapper(env, shield=shield, keys=keys) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) env = OneHotWrapper(env, @@ -142,10 +148,12 @@ def env_creater(config): return env +def create_log_dir(args): + return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}" -def create_shield(grid_file, prism_path): - os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}") +def create_shield(grid_to_prism_path, grid_file, prism_path): + os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}") f = open(prism_path, "a") f.write("label \"AgentIsInLava\" = AgentIsInLava;") @@ -154,7 +162,12 @@ def create_shield(grid_file, prism_path): program = stormpy.parse_prism_program(prism_path) formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" - #formula_str = "Pmax=? [G \"AgentIsInGoalAndNotDone\"]" + # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, + stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) + + + # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9) formulas = stormpy.parse_properties_for_prism_program(formula_str, program) options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) @@ -163,7 +176,6 @@ def create_shield(grid_file, prism_path): options.set_build_all_labels() model = stormpy.build_sparse_model_with_options(program, options) - shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) assert result.has_scheduler @@ -189,7 +201,6 @@ def export_grid_to_text(env, grid_file): f = open(grid_file, "w") # print(env) f.write(env.printGrid(init=True)) - # f.write(env.pprint_grid()) f.close() def create_environment(args): @@ -210,14 +221,14 @@ def register_custom_minigrid_env(args): def create_shield_dict(args): env = create_environment(args) - # print(env.pprint_grid()) # print(env.printGrid(init=False)) grid_file = args.grid_path + grid_to_prism_path = args.grid_to_prism_path export_grid_to_text(env, grid_file) prism_path = args.prism_path - shield_dict = create_shield(grid_file, prism_path) + shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path) #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} #print(F"Shield dictionary {shield_dict}") @@ -238,13 +249,13 @@ def ppo(args): config = (PPOConfig() .rollouts(num_rollout_workers=1) .resources(num_gpus=0) - .environment(env="mini-grid", env_config={"shield": shield_dict}) + .environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env}) .framework("torch") .callbacks(MyCallbacks) .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ "type": "ray.tune.logger.TBXLogger", - "logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}" + "logdir": create_log_dir(args) }) .training(_enable_learner_api=False ,model={ "custom_model": "pa_model", @@ -279,13 +290,13 @@ def dqn(args): config = DQNConfig() config = config.resources(num_gpus=0) config = config.rollouts(num_rollout_workers=1) - config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) + config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env }) config = config.framework("torch") config = config.callbacks(MyCallbacks) config = config.rl_module(_enable_rl_module_api = False) config = config.debugging(logger_config={ "type": "ray.tune.logger.TBXLogger", - "logdir": F"{args.log_dir}{datetime.now()}-{args.algorithm}" + "logdir": create_log_dir(args) }) config = config.training(hiddens=[], dueling=False, model={ "custom_model": "pa_model", @@ -296,13 +307,13 @@ def dqn(args): config.build() ) - for i in range(30): + for i in range(args.iterations): result = algo.train() print(pretty_print(result)) - if i % 5 == 0: - checkpoint_dir = algo.save() - print(f"Checkpoint saved in directory {checkpoint_dir}") + # if i % 5 == 0: + # checkpoint_dir = algo.save() + # print(f"Checkpoint saved in directory {checkpoint_dir}") ray.shutdown() diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index 6118e2d..d67e364 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -6,6 +6,8 @@ from gymnasium.spaces import Dict, Box from collections import deque from ray.rllib.utils.numpy import one_hot +from helpers import get_action_index_mapping + class OneHotWrapper(gym.core.ObservationWrapper): def __init__(self, env, vector_index, framestack): super().__init__(env) @@ -82,7 +84,7 @@ class OneHotWrapper(gym.core.ObservationWrapper): class MiniGridEnvWrapper(gym.core.Wrapper): - def __init__(self, env, shield={}): + def __init__(self, env, shield={}, keys=[]): super(MiniGridEnvWrapper, self).__init__(env) self.max_available_actions = env.action_space.n self.observation_space = Dict( @@ -91,29 +93,43 @@ class MiniGridEnvWrapper(gym.core.Wrapper): "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), } ) + self.keys = keys self.shield = shield - def create_action_mask(self): coordinates = self.env.agent_pos view_direction = self.env.agent_dir + + key_text = "" + + # only support one key for now + if self.keys: + key_text = F"!Agent_has_{self.keys[0]}_key\t& " + + + if self.env.carrying and self.env.carrying.type == "key": + key_text = F"Agent_has_{self.env.carrying.color}_key\t& " + #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") - cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" + cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" allowed_actions = [] # Create the mask # If shield restricts action mask only valid with 1.0 - # else set everything to one + # else set all actions as valid mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) if cur_pos_str in self.shield and self.shield[cur_pos_str]: allowed_actions = self.shield[cur_pos_str] for allowed_action in allowed_actions: - index = allowed_action[0] + index = get_action_index_mapping(allowed_action[1]) + if index is None: + assert(False) mask[index] = 1.0 else: + print("Not in shield") for index, x in enumerate(mask): mask[index] = 1.0 diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py new file mode 100644 index 0000000..aa8738b --- /dev/null +++ b/examples/shields/rl/helpers.py @@ -0,0 +1,39 @@ +import minigrid +from minigrid.core.actions import Actions + + +def extract_keys(env): + env.reset() + keys = [] + print(env.grid) + for j in range(env.grid.height): + for i in range(env.grid.width): + obj = env.grid.get(i,j) + + if obj and obj.type == "key": + keys.append(obj.color) + + return keys + + +def get_action_index_mapping(actions): + for action_str in actions: + if "left" in action_str: + return Actions.left + elif "right" in action_str: + return Actions.right + elif "east" in action_str: + return Actions.forward + elif "south" in action_str: + return Actions.forward + elif "west" in action_str: + return Actions.forward + elif "north" in action_str: + return Actions.forward + elif "pickup" in action_str: + return Actions.pickup + elif "done" in action_str: + return Actions.done + + + raise ValueError(F"Action string {action_str} not supported") \ No newline at end of file