diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index d2f125c..fcdf9dd 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -24,9 +24,10 @@ def shielding_env_creater(config): args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) - env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, + shield_query_creator=create_shield_query, + create_shield_at_reset=args.shield_creation_at_reset) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) env = OneHotShieldingWrapper(env, @@ -78,6 +79,8 @@ def ppo(args): if i % 5 == 0: checkpoint_dir = algo.save() print(f"Checkpoint saved in directory {checkpoint_dir}") + + algo.save() def dqn(args): diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 7e97e27..9c5b520 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -71,8 +71,12 @@ def get_action_index_mapping(actions): return Actions.done elif "drop" in action_str: return Actions.drop + elif "toggle" in action_str: + return Actions.toggle + elif "unlock" in action_str: + return Actions.toggle - raise ValueError(F"Action string {action_str} not supported") + return Actions.done diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/shieldhandlers.py index d4271e4..1f95228 100644 --- a/examples/shields/rl/shieldhandlers.py +++ b/examples/shields/rl/shieldhandlers.py @@ -66,7 +66,7 @@ class MiniGridShieldHandler(ShieldHandler): assert result.has_scheduler assert result.has_shield shield = result.shield - # stormpy.shields.export_shield(model, shield, "Grid.shield") + stormpy.shields.export_shield(model, shield, "Grid.shield") action_dictionary = {} shield_scheduler = shield.construct() diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py index 3f0c2f8..74d0c7e 100644 --- a/examples/shields/rl/wrappers.py +++ b/examples/shields/rl/wrappers.py @@ -8,7 +8,7 @@ from gymnasium.spaces import Dict, Box from collections import deque from ray.rllib.utils.numpy import one_hot -from helpers import get_action_index_mapping, extract_keys +from helpers import get_action_index_mapping from shieldhandlers import ShieldHandler @@ -67,7 +67,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): image = obs["data"] - # One-hot the last dim into 12, 6, 3 one-hot vectors, then flatten. + # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. objects = one_hot(image[:, :, 0], depth=16) colors = one_hot(image[:, :, 1], depth=6) states = one_hot(image[:, :, 2], depth=3) @@ -108,7 +108,9 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): return np.array([1.0] * self.max_available_actions, dtype=np.int8) cur_pos_str = self.shield_query_creator(self.env) - + # print(F"Pos string {cur_pos_str}") + # print(F"Shield {list(self.shield.keys())[0]}") + # print(cur_pos_str in self.shield) # Create the mask # If shield restricts action mask only valid with 1.0 # else set all actions as valid @@ -120,6 +122,8 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): for allowed_action in allowed_actions: index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set if index is None: + print(F"No mapping for action {list(allowed_action.labels)}") + print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}") assert(False) allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] @@ -134,8 +138,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): if front_tile is not None and front_tile.type == "key": mask[Actions.pickup] = 1.0 - # if self.env.carrying: - # mask[Actions.drop] = 1.0 if front_tile and front_tile.type == "door": mask[Actions.toggle] = 1.0 @@ -148,7 +150,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): if self.create_shield_at_reset and self.mask_actions: self.shield = self.shield_creator.create_shield(env=self.env) - self.keys = extract_keys(self.env) mask = self.create_action_mask() return { "data": obs["image"], @@ -164,7 +165,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): "action_mask": mask, } - #print(F"Info is {info}") return obs, rew, done, truncated, info @@ -204,7 +204,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): index = get_action_index_mapping(allowed_action.labels) if index is None: assert(False) - + mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] else: for index, x in enumerate(mask): @@ -222,10 +222,8 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) - keys = extract_keys(self.env) shield = self.shield_creator.create_shield(env=self.env) - self.keys = keys self.shield = shield return obs["image"], infos