From 757fbbcc0df24f790d209756f03c15456faf6227 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Mon, 4 Sep 2023 12:06:00 +0200 Subject: [PATCH] fixed shield generation worker handling --- examples/shields/rl/11_minigridrl.py | 32 +++++++-------------- examples/shields/rl/13_minigridsb.py | 10 ++----- examples/shields/rl/MaskModels.py | 14 ++-------- examples/shields/rl/Wrapper.py | 11 +------- examples/shields/rl/helpers.py | 42 ++++++++-------------------- 5 files changed, 28 insertions(+), 81 deletions(-) diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index ff7d8e1..4162189 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -22,10 +22,9 @@ from ray.rllib.models import ModelCatalog from ray.rllib.utils.torch_utils import FLOAT_MIN -from ray.rllib.models.preprocessors import get_preprocessor from MaskModels import TorchActionMaskModel from Wrapper import OneHotWrapper, MiniGridEnvWrapper -from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir +from helpers import parse_arguments, create_log_dir import matplotlib.pyplot as plt @@ -34,7 +33,9 @@ class MyCallbacks(DefaultCallbacks): # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] episode.user_data["count"] = 0 + # print("On episode start print") # print(env.printGrid()) + # print(worker) # print(env.action_space.n) # print(env.actions) # print(env.mission) @@ -52,17 +53,19 @@ class MyCallbacks(DefaultCallbacks): def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] - # print(env.printGrid()) - # print(episode.user_data["count"]) + #print("On episode end print") + #print(env.printGrid()) def env_creater_custom(config): framestack = config.get("framestack", 4) - shield = config.get("shield", {}) name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") framestack = config.get("framestack", 4) args = config.get("args", None) + args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" + args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" + env = gym.make(name) env = MiniGridEnvWrapper(env, args=args) # env = minigrid.wrappers.ImgObsWrapper(env) @@ -88,14 +91,10 @@ def register_custom_minigrid_env(args): def ppo(args): - - ray.init(num_cpus=1) - - register_custom_minigrid_env(args) config = (PPOConfig() - .rollouts(num_rollout_workers=1) + .rollouts(num_rollout_workers=args.workers) .resources(num_gpus=0) .environment(env="mini-grid", env_config={"name": args.env, "args": args}) .framework("torch") @@ -123,15 +122,6 @@ def ppo(args): checkpoint_dir = algo.save() print(f"Checkpoint saved in directory {checkpoint_dir}") - # terminated = truncated = False - - # while not terminated and not truncated: - # action = algo.compute_single_action(obs) - # obs, reward, terminated, truncated = env.step(action) - - - ray.shutdown() - def dqn(args): register_custom_minigrid_env(args) @@ -139,7 +129,7 @@ def dqn(args): config = DQNConfig() config = config.resources(num_gpus=0) - config = config.rollouts(num_rollout_workers=1) + config = config.rollouts(num_rollout_workers=args.workers) config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args }) config = config.framework("torch") config = config.callbacks(MyCallbacks) @@ -166,8 +156,6 @@ def dqn(args): checkpoint_dir = algo.save() print(f"Checkpoint saved in directory {checkpoint_dir}") - ray.shutdown() - def main(): import argparse diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 6a2fb20..7959319 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -93,7 +93,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper): def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) - + keys = extract_keys(self.env) shield = create_shield_dict(self.env, self.args) @@ -102,13 +102,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper): return obs["image"], infos def step(self, action): - # print(F"Performed action in step: {action}") - orig_obs, rew, done, truncated, info = self.env.step(action) - - #print(F"Original observation is {orig_obs}") + orig_obs, rew, done, truncated, info = self.env.step(action) obs = orig_obs["image"] - - #print(F"Info is {info}") + return obs, rew, done, truncated, info diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py index 0ee4154..e882a51 100644 --- a/examples/shields/rl/MaskModels.py +++ b/examples/shields/rl/MaskModels.py @@ -24,10 +24,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): ): orig_space = getattr(obs_space, "original_space", obs_space) custom_config = model_config['custom_model_config'] - # print(F"Original Space is: {orig_space}") - #print(model_config) - #print(F"Observation space in model: {obs_space}") - #print(F"Provided action space in model {action_space}") TorchModelV2.__init__( self, obs_space, action_space, num_outputs, model_config, name, **kwargs @@ -43,17 +39,15 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): model_config, name + "_internal", ) - - # disable action masking --> will likely lead to invalid actions + self.no_masking = False if "no_masking" in model_config["custom_model_config"]: self.no_masking = model_config["custom_model_config"]["no_masking"] def forward(self, input_dict, state, seq_lens): # Extract the available actions tensor from the observation. - # Compute the unmasked logits. + # Compute the unmasked logits. logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]}) - action_mask = input_dict["obs"]["action_mask"] @@ -61,13 +55,11 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): if self.no_masking: return logits, state - # assert(False) - # Convert action_mask into a [0.0 || -inf]-type mask. inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) masked_logits = logits + inf_mask - # # Return masked logits. + # Return masked logits. return masked_logits, state def value_function(self): diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index 6650369..390974b 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -34,10 +34,6 @@ class OneHotWrapper(gym.core.ObservationWrapper): } ) - - # print(F"Set obersvation space to {self.observation_space}") - - def observation(self, obs): # Debug output: max-x/y positions to watch exploration progress. # print(F"Initial observation in Wrapper {obs}") @@ -80,9 +76,8 @@ class OneHotWrapper(gym.core.ObservationWrapper): single_frame = np.concatenate([all_flat, direction]) self.frame_buffer.append(single_frame) - #obs["one-hot"] = np.concatenate(self.frame_buffer) tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } - return tmp#np.concatenate(self.frame_buffer) + return tmp class MiniGridEnvWrapper(gym.core.Wrapper): @@ -111,7 +106,6 @@ class MiniGridEnvWrapper(gym.core.Wrapper): if self.env.carrying and self.env.carrying.type == "key": key_text = F"Agent_has_{self.env.carrying.color}_key\t& " - #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" allowed_actions = [] @@ -130,7 +124,6 @@ class MiniGridEnvWrapper(gym.core.Wrapper): assert(False) mask[index] = 1.0 else: - # print(F"Not in shield {cur_pos_str}") for index, x in enumerate(mask): mask[index] = 1.0 @@ -158,11 +151,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper): }, infos def step(self, action): - # print(F"Performed action in step: {action}") orig_obs, rew, done, truncated, info = self.env.step(action) mask = self.create_action_mask() - #print(F"Original observation is {orig_obs}") obs = { "data": orig_obs["image"], "action_mask": mask, diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index ca15f14..9106745 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -18,7 +18,6 @@ import os def extract_keys(env): - env.reset() keys = [] #print(env.grid) for j in range(env.grid.height): @@ -66,6 +65,7 @@ def parse_arguments(argparse): default="MiniGrid-LavaCrossingS9N1-v0", choices=[ "MiniGrid-LavaCrossingS9N1-v0", + "MiniGrid-LavaCrossingS9N3-v0", "MiniGrid-DoorKey-8x8-v0", "MiniGrid-LockedRoom-v0", "MiniGrid-FourRooms-v0", @@ -77,26 +77,21 @@ def parse_arguments(argparse): # parser.add_argument("--seed", type=int, help="seed for environment", default=None) parser.add_argument("--grid_to_prism_path", default="./main") - parser.add_argument("--grid_path", default="Grid.txt") - parser.add_argument("--prism_path", default="Grid.PRISM") + parser.add_argument("--grid_path", default="grid") + parser.add_argument("--prism_path", default="grid") parser.add_argument("--no_masking", default=False) parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) parser.add_argument("--log_dir", default="../log_results/") parser.add_argument("--iterations", type=int, default=30 ) + parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" + parser.add_argument("--workers", type=int, default=1) + args = parser.parse_args() return args - -def create_environment(args): - env_id= args.env - env = gym.make(env_id) - env.reset() - return env - - def export_grid_to_text(env, grid_file): f = open(grid_file, "w") # print(env) @@ -104,24 +99,18 @@ def export_grid_to_text(env, grid_file): f.close() -def create_shield(grid_to_prism_path, grid_file, prism_path): +def create_shield(grid_to_prism_path, grid_file, prism_path, formula): os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}") f = open(prism_path, "a") f.write("label \"AgentIsInLava\" = AgentIsInLava;") f.close() - - + program = stormpy.parse_prism_program(prism_path) - formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" - # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" - # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, - # stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) - + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) - # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.9) - formulas = stormpy.parse_properties_for_prism_program(formula_str, program) + formulas = stormpy.parse_properties_for_prism_program(formula, program) options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) options.set_build_state_valuations(True) options.set_build_choice_labels(True) @@ -151,21 +140,12 @@ def create_shield(grid_to_prism_path, grid_file, prism_path): def create_shield_dict(env, args): - env = create_environment(args) - # print(env.printGrid(init=False)) - grid_file = args.grid_path grid_to_prism_path = args.grid_to_prism_path export_grid_to_text(env, grid_file) prism_path = args.prism_path - shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path) - #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} - - #print(F"Shield dictionary {shield_dict}") - # for state_id in model.states: - # choices = shield.get_choice(state_id) - # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") + shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path, args.formula) return shield_dict