diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py old mode 100644 new mode 100755 index fcdf9dd..569d0dc --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -20,13 +20,15 @@ def shielding_env_creater(config): name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") framestack = config.get("framestack", 4) args = config.get("args", None) - args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" - args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" + args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt" + args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism" - shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) + + shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config) env = gym.make(name) env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, + mask_actions=args.shielding != ShieldingConfig.Disabled, create_shield_at_reset=args.shield_creation_at_reset) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) @@ -63,15 +65,18 @@ def ppo(args): .debugging(logger_config={ "type": TBXLogger, "logdir": create_log_dir(args) - }) + }) + # .exploration(exploration_config={"exploration_fraction": 0.1}) .training(_enable_learner_api=False ,model={ "custom_model": "shielding_model" })) - + # config.entropy_coeff = 0.05 algo =( config.build() - ) + ) + + for i in range(args.evaluations): result = algo.train() print(pretty_print(result)) diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index ba8c8e2..6a8f0df 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -33,8 +33,8 @@ def shielding_env_creater(config): prism_path=args.prism_path, formula=args.formula) - env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) + env = gym.make(name, randomize_start=True) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, "vector_index") else 0, diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index fafd36d..240fec7 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -44,6 +44,14 @@ class MyCallbacks(DefaultCallbacks): episode.user_data["count"] = episode.user_data["count"] + 1 env = base_env.get_sub_environments()[0] # print(env.printGrid()) + + if hasattr(env, "adversaries"): + for adversary in env.adversaries.values(): + if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]: + print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}") + # assert False + + def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 684065c..3e8112f 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -64,7 +64,7 @@ def extract_adversaries(env): return adv def create_log_dir(args): - return F"{args.log_dir}sh:{args.shielding}-env:{args.env}" + return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-env:{args.env}-conf:{args.prism_config}" def test_name(args): return F"{args.expname}" @@ -99,7 +99,7 @@ def parse_arguments(argparse): # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") parser.add_argument("--env", help="gym environment to load", - default="MiniGrid-LavaCrossingS9N1-v0", + default="MiniGrid-LavaSlipperyS12-v2", choices=[ "MiniGrid-Adv-8x8-v0", "MiniGrid-AdvSimple-8x8-v0", @@ -128,14 +128,17 @@ def parse_arguments(argparse): parser.add_argument("--prism_path", default="grid") parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) parser.add_argument("--log_dir", default="../log_results/") - parser.add_argument("--evaluations", type=int, default=10 ) - # parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" - parser.add_argument("--formula", default="Pmax=? [G !\"AgentRanIntoAdversary\"]") + parser.add_argument("--evaluations", type=int, default=30 ) + parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" + # parser.add_argument("--formula", default="<> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]") parser.add_argument("--workers", type=int, default=1) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) parser.add_argument("--steps", default=20_000, type=int) parser.add_argument("--expname", default="exp") parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction) + parser.add_argument("--prism_config", default=None) + parser.add_argument("--shield_value", default=0.9, type=float) + # parser.add_argument("--random_starts", default=1, type=int) args = parser.parse_args() return args diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/shieldhandlers.py index 30d15ff..68140a8 100644 --- a/examples/shields/rl/shieldhandlers.py +++ b/examples/shields/rl/shieldhandlers.py @@ -27,11 +27,13 @@ class ShieldHandler(ABC): pass class MiniGridShieldHandler(ShieldHandler): - def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None: + def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.9 ,prism_config=None) -> None: self.grid_file = grid_file self.grid_to_prism_path = grid_to_prism_path self.prism_path = prism_path self.formula = formula + self.prism_config = prism_config + self.shield_value = shield_value def __export_grid_to_text(self, env): f = open(self.grid_file, "w") @@ -40,7 +42,12 @@ class MiniGridShieldHandler(ShieldHandler): def __create_prism(self): - result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml") + # result = os.system(F"{self.grid_to_prism_path} -v 'Agent,Blue' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml") + if self.prism_config is None: + result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path}") + else: + result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}") + # result = os.system(F"{self.grid_to_prism_path} -v 'Agent' -i {self.grid_file} -o {self.prism_path} -c adv_config.yaml") assert result == 0, "Prism file could not be generated" @@ -51,7 +58,7 @@ class MiniGridShieldHandler(ShieldHandler): def __create_shield_dict(self): print(self.prism_path) program = stormpy.parse_prism_program(self.prism_path) - shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9) + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, self.shield_value) formulas = stormpy.parse_properties_for_prism_program(self.formula, program) options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) @@ -175,7 +182,7 @@ def create_shield_query(env): if adversaries: move_text = F"move=0\t& " - agent_position = F"xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}" + agent_position = F"& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}" query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]" return query diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py index 7534775..e706ff0 100644 --- a/examples/shields/rl/wrappers.py +++ b/examples/shields/rl/wrappers.py @@ -102,15 +102,20 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): self.shield = shield_creator.create_shield(env=self.env) self.mask_actions = mask_actions self.shield_query_creator = shield_query_creator + print(F"Shielding is {self.mask_actions}") def create_action_mask(self): + # print(F"{self.mask_actions} No shielding") if not self.mask_actions: - return np.array([1.0] * self.max_available_actions, dtype=np.int8) + # print("No shielding") + ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) + # print(ret) + return ret cur_pos_str = self.shield_query_creator(self.env) # print(F"Pos string {cur_pos_str}") # print(F"Shield {list(self.shield.keys())[0]}") - # print(F"Is pos str in shield: {cur_pos_str in self.shield}") + # print(F"Is pos str in shield: {cur_pos_str in self.shield}, Position Str {cur_pos_str}") # Create the mask # If shield restricts action mask only valid with 1.0 # else set all actions as valid @@ -119,6 +124,9 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): if cur_pos_str in self.shield and self.shield[cur_pos_str]: allowed_actions = self.shield[cur_pos_str] + zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8) + has_allowed_actions = False + for allowed_action in allowed_actions: index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set if index is None: @@ -129,8 +137,14 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] if allowed_action.prob == 0 and allowed: assert False + if allowed: + has_allowed_actions = True mask[index] = allowed - + + # if not has_allowed_actions: + # print(F"No action allowed for pos string {cur_pos_str}") + # assert(False) + else: for index, x in enumerate(mask): mask[index] = 1.0 @@ -186,6 +200,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): self.shield_creator = shield_creator self.mask_actions = mask_actions self.shield_query_creator = shield_query_creator + print(F"Shielding is {self.mask_actions}") def create_action_mask(self): if not self.mask_actions: