From f52262ad114bb1ee30fa3c652d30fca456c5cbf5 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Wed, 23 Aug 2023 14:38:57 +0200 Subject: [PATCH] simple masking (only turn left allowed) --- examples/shields/rl/11_minigridrl.py | 39 ++++++++++++++--------- examples/shields/rl/MaskModels.py | 23 ++++++++------ examples/shields/rl/Wrapper.py | 47 ++++++++++++++++++++++------ 3 files changed, 76 insertions(+), 33 deletions(-) diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index b19309d..f77c7e6 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -82,6 +82,7 @@ def parse_arguments(argparse): parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") parser.add_argument("--grid_path", default="Grid.txt") parser.add_argument("--prism_path", default="Grid.PRISM") + parser.add_argument("--no_masking", default=False) args = parser.parse_args() @@ -92,14 +93,14 @@ def env_creater_custom(config): # name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") # # name = config.get("name", "MiniGrid-Empty-8x8-v0") framestack = config.get("framestack", 4) - + shield = config.get("shield", {}) # env = gym.make(name) # env = ParametricActionsMiniGridEnv(config) name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") framestack = config.get("framestack", 4) env = gym.make(name) - env = MiniGridEnvWrapper(env) + env = MiniGridEnvWrapper(env, shield=shield) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) env = OneHotWrapper(env, @@ -163,10 +164,21 @@ def create_shield(grid_file, prism_path): assert result.has_scheduler assert result.has_shield shield = result.shield + + action_dictionary = {} + shield_scheduler = shield.construct() + + for stateID in model.states: + choice = shield_scheduler.get_choice(stateID) + choices = choice.choice_map + state_valuation = model.state_valuations.get_string(stateID) + + actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] + + action_dictionary[state_valuation] = actions_to_be_executed stormpy.shields.export_shield(model, shield, "Grid.shield") - - return shield.construct(), model + return action_dictionary def export_grid_to_text(env, grid_file): f = open(grid_file, "w") @@ -195,13 +207,13 @@ def main(): export_grid_to_text(env, grid_file) prism_path = args.prism_path - shield, model = create_shield(grid_file, prism_path) - shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} + shield_dict = create_shield(grid_file, prism_path) + #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} - print(shield_dict) - for state_id in model.states: - choices = shield.get_choice(state_id) - print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") + print(F"Shield dictionary {shield_dict}") + # for state_id in model.states: + # choices = shield.get_choice(state_id) + # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") env_name = "mini-grid" register_env(env_name, env_creater_custom) @@ -213,14 +225,14 @@ def main(): config = (PPOConfig() .rollouts(num_rollout_workers=1) .resources(num_gpus=0) - .environment(env="mini-grid") + .environment(env="mini-grid", env_config={"shield": shield_dict }) .framework("torch") .experimental(_disable_preprocessor_api=False) .callbacks(MyCallbacks) .rl_module(_enable_rl_module_api = False) .training(_enable_learner_api=False ,model={ "custom_model": "pa_model", - "custom_model_config" : {"shield": shield_dict, "no_masking": True} + "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} # "fcnet_hiddens": [256,256], # "fcnet_activation": "relu", @@ -231,9 +243,6 @@ def main(): config.build() ) - episode_reward = 0 - terminated = truncated = False - obs, info = env.reset() # while not terminated and not truncated: # action = algo.compute_single_action(obs) diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py index b017740..71b4418 100644 --- a/examples/shields/rl/MaskModels.py +++ b/examples/shields/rl/MaskModels.py @@ -28,6 +28,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): print(F"Original Space is: {orig_space}") #print(model_config) print(F"Observation space in model: {obs_space}") + print(F"Provided action space in model {action_space}") TorchModelV2.__init__( self, obs_space, action_space, num_outputs, model_config, name, **kwargs @@ -37,6 +38,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): assert("shield" in custom_config) self.shield = custom_config["shield"] + self.count = 0 self.internal_model = TorchFC( orig_space["data"], @@ -54,28 +56,31 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): def forward(self, input_dict, state, seq_lens): # Extract the available actions tensor from the observation. # print(F"Input dict is {input_dict} at obs: {input_dict['obs']}") - # print(F"State is {state}") - - action_mask = [] + # print(F"State is {state}") # print(input_dict["env"]) # Compute the unmasked logits. logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]}) + # print(F"Caluclated Logits {logits} with size {logits.size()} Count: {self.count}") + + action_mask = input_dict["obs"]["avail_actions"] + #print(F"Action mask is {action_mask} with dimension {action_mask.size()}") + # If action masking is disabled, directly return unmasked logits if self.no_masking: return logits, state - assert(False) - - return logits, state + # assert(False) # Convert action_mask into a [0.0 || -inf]-type mask. - # inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) - # masked_logits = logits + inf_mask + inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) + masked_logits = logits + inf_mask + + print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}") # # Return masked logits. - # return masked_logits, state + return masked_logits, state def value_function(self): return self.internal_model.value_function() diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index c1e5e69..183676f 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -26,7 +26,7 @@ class OneHotWrapper(gym.core.ObservationWrapper): self.observation_space = Dict( { "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), - "avail_actions": gym.spaces.Box(0, 10, shape=(10,), dtype=int), + "avail_actions": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), } ) @@ -82,34 +82,63 @@ class OneHotWrapper(gym.core.ObservationWrapper): class MiniGridEnvWrapper(gym.core.Wrapper): - def __init__(self, env): + def __init__(self, env, shield): super(MiniGridEnvWrapper, self).__init__(env) + self.max_available_actions = env.action_space.n self.observation_space = Dict( { "data": env.observation_space.spaces["image"], - "avail_actions" : Box(0, 10, shape=(10,), dtype=np.int8), + "avail_actions" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), } ) + self.shield = shield + + + def create_action_mask(self): + coordinates = self.env.agent_pos + view_direction = self.env.agent_dir + print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") + cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" + + allowed_actions = [] - def test(self): - print("Testing some stuff") + + # Create the mask + # If shield restricts action mask only valid with 1.0 + # else set everything to one + mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) + + # if cur_pos_str in self.shield: + # allowed_actions = self.shield[cur_pos_str] + # for allowed_action in allowed_actions: + # index = allowed_action[0] + # mask[index] = 1.0 + # else: + # for index in len(mask): + # mask[index] = 1.0 + + + print(F"Allowed actions for position {coordinates} and view {view_direction} are {allowed_actions}") + mask[0] = 1.0 + return mask def reset(self, *, seed=None, options=None): obs, infos = self.env.reset() return { "data": obs["image"], - "avail_actions": np.array([0.0] * 10, dtype=np.int8) + "avail_actions": np.array([0.0] * self.max_available_actions, dtype=np.int8) }, infos def step(self, action): + print(F"Performed action in step: {action}") orig_obs, rew, done, truncated, info = self.env.step(action) - - self.test() + + actions = self.create_action_mask() #print(F"Original observation is {orig_obs}") obs = { "data": orig_obs["image"], - "avail_actions": np.array([0.0] * 10, dtype=np.int8), + "avail_actions": actions, } #print(F"Info is {info}")