From afc9f5bc4d142814891a8db0d67483929502e66c Mon Sep 17 00:00:00 2001
From: Thomas Knoll
Date: Fri, 29 Sep 2023 15:38:54 +0200
Subject: [PATCH] added config / some adversary fixes

---
 adv_config.yaml                       | 12 +++++++++
 examples/shields/rl/callbacks.py      | 37 +++++++++++++++++++--------
 examples/shields/rl/helpers.py        | 14 ++++++----
 examples/shields/rl/shieldhandlers.py | 11 +++++---
 examples/shields/rl/wrappers.py       |  6 +++--
 5 files changed, 59 insertions(+), 21 deletions(-)
 create mode 100644 adv_config.yaml

diff --git a/adv_config.yaml b/adv_config.yaml
new file mode 100644
index 0000000..76ac00d
--- /dev/null
+++ b/adv_config.yaml
@@ -0,0 +1,12 @@
+---
+labels:
+  - label: "AgentIsInGoal"
+    text: "AgentIsInGoal"
+  - label: "AgentRanIntoAdversary"
+    text: "AgentRanIntoAdversary"
+
+formulas:
+  - formula: "AgentRanIntoAdversary"
+    content: "(xAgent=xBlue) & (yAgent=yBlue)"
+
+...
diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py
index 57c4119..fafd36d 100644
--- a/examples/shields/rl/callbacks.py
+++ b/examples/shields/rl/callbacks.py
@@ -1,8 +1,9 @@
-from typing import Dict
+from typing import Dict, Optional
 
+from ray.rllib.env.env_context import EnvContext
 from ray.rllib.policy import Policy
-from ray.rllib.utils.typing import PolicyID
+from ray.rllib.utils.typing import EnvType, PolicyID
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.env.base_env import BaseEnv
 
 
@@ -12,6 +13,10 @@
 from ray.rllib.evaluation.episode_v2 import EpisodeV2
 from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
 
+import matplotlib.pyplot as plt
+
+
+
 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
@@ -19,8 +24,11 @@ class MyCallbacks(DefaultCallbacks):
         episode.user_data["count"] = 0
         episode.user_data["ran_into_lava"] = []
         episode.user_data["goals_reached"] = []
+        episode.user_data["ran_into_adversary"] = []
         episode.hist_data["ran_into_lava"] = []
         episode.hist_data["goals_reached"] = []
+        episode.hist_data["ran_into_adversary"] = []
+
         # print("On episode start print")
         # print(env.printGrid())
         # print(worker)
@@ -28,30 +36,39 @@ class MyCallbacks(DefaultCallbacks):
         # print(env.actions)
         # print(env.mission)
         # print(env.observation_space)
-
         # img = env.get_frame()
         # plt.imshow(img)
         # plt.show()
-        
+
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
-        episode.user_data["count"] = episode.user_data["count"] + 1
-        env = base_env.get_sub_environments()[0]
-        # print(env.printGrid())
+        episode.user_data["count"] = episode.user_data["count"] + 1
+        env = base_env.get_sub_environments()[0]
+        # print(env.printGrid())
 
     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
         agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
-        
+        ran_into_adversary = False
+
+        if hasattr(env, "adversaries"):
+            adversaries = env.adversaries.values()
+            for adversary in adversaries:
+                if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
+                    ran_into_adversary = True
+                    break
+
         episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
         episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
+        episode.user_data["ran_into_adversary"].append(ran_into_adversary)
         episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal"
         episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava"
+        episode.custom_metrics["ran_into_adversary"] = ran_into_adversary
         #print("On episode end print")
-        #print(env.printGrid())
+        # print(env.printGrid())
         episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
         episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
-    
+        episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"]
 
     def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
         print("Evaluate Start")
diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py
index 20dac9e..684065c 100644
--- a/examples/shields/rl/helpers.py
+++ b/examples/shields/rl/helpers.py
@@ -71,6 +71,9 @@ def test_name(args):
 
 def get_action_index_mapping(actions):
     for action_str in actions:
+        if not "Agent" in action_str:
+            continue
+
         if "move" in action_str:
             return Actions.forward
         elif "left" in action_str:
@@ -88,9 +91,8 @@ def get_action_index_mapping(actions):
         elif "unlock" in action_str:
             return Actions.toggle
 
-    return Actions.done
-
-
+    raise ValueError("No action mapping found")
+
 
 def parse_arguments(argparse):
     parser = argparse.ArgumentParser()
@@ -100,6 +102,8 @@ def parse_arguments(argparse):
                         default="MiniGrid-LavaCrossingS9N1-v0",
                         choices=[
                             "MiniGrid-Adv-8x8-v0",
+                            "MiniGrid-AdvSimple-8x8-v0",
+                            "MiniGrid-SingleDoor-7x6-v0",
                             "MiniGrid-LavaCrossingS9N1-v0",
                             "MiniGrid-LavaCrossingS9N3-v0",
                             "MiniGrid-LavaSlipperyS12-v0",
@@ -110,7 +114,6 @@ def parse_arguments(argparse):
                         # "MiniGrid-DoubleDoor-16x16-v0",
                         # "MiniGrid-DoubleDoor-12x12-v0",
                         # "MiniGrid-DoubleDoor-10x8-v0",
-                        # "MiniGrid-SingleDoor-7x6-v0",
                         # "MiniGrid-LockedRoom-v0",
                         # "MiniGrid-FourRooms-v0",
                         # "MiniGrid-LavaGapS7-v0",
@@ -126,7 +129,8 @@ def parse_arguments(argparse):
     parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
     parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--evaluations", type=int, default=10 )
-    parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
+    # parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
+    parser.add_argument("--formula", default="Pmax=? [G !\"AgentRanIntoAdversary\"]")
     parser.add_argument("--workers", type=int, default=1)
     parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
     parser.add_argument("--steps", default=20_000, type=int)
diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/shieldhandlers.py
index 3cf95a3..30d15ff 100644
--- a/examples/shields/rl/shieldhandlers.py
+++ b/examples/shields/rl/shieldhandlers.py
@@ -40,13 +40,12 @@ class MiniGridShieldHandler(ShieldHandler):
 
     def __create_prism(self):
-        result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path}")
+        result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml")
 
         assert result == 0, "Prism file could not be generated"
 
         f = open(self.prism_path, "a")
         f.write("label \"AgentIsInLava\" = AgentIsInLava;")
-        f.write("label \"AgentIsInGoal\" = AgentIsInGoal;")
         f.close()
 
 
     def __create_shield_dict(self):
@@ -63,7 +62,6 @@ class MiniGridShieldHandler(ShieldHandler):
 
         result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
 
-        assert result.has_scheduler
         assert result.has_shield
         shield = result.shield
         stormpy.shields.export_shield(model, shield, "Grid.shield")
@@ -172,8 +170,13 @@ def create_shield_query(env):
     if key_positions:
         key_positions_text = F"\t& {''.join(key_positions)}"
 
 
+    move_text = ""
+
+    if adversaries:
+        move_text = F"move=0\t& "
+
     agent_position = F"xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}"
-    query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
+    query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
 
     return query
\ No newline at end of file
diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py
index 4f1d94d..7534775 100644
--- a/examples/shields/rl/wrappers.py
+++ b/examples/shields/rl/wrappers.py
@@ -110,7 +110,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
         cur_pos_str = self.shield_query_creator(self.env)
         # print(F"Pos string {cur_pos_str}")
         # print(F"Shield {list(self.shield.keys())[0]}")
-        # print(cur_pos_str in self.shield)
+        # print(F"Is pos str in shield: {cur_pos_str in self.shield}")
         # Create the mask
         # If shield restricts action mask only valid with 1.0
         # else set all actions as valid
@@ -127,6 +127,8 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
                     assert(False)
 
                 allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
+                if allowed_action.prob == 0 and allowed:
+                    assert False
 
                 mask[index] = allowed
         else:
@@ -141,7 +143,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
 
         if front_tile and front_tile.type == "door":
             mask[Actions.toggle] = 1.0
-        
+        # print(F"Mask is {mask} State: {cur_pos_str}")
         return mask
 
     def reset(self, *, seed=None, options=None):