8 changed files with 420 additions and 182 deletions
			
			
		- 
					46examples/shields/rl/11_minigridrl.py
 - 
					33examples/shields/rl/12_minigridrl_tune.py
 - 
					23examples/shields/rl/13_minigridsb.py
 - 
					34examples/shields/rl/14_train_eval.py
 - 
					65examples/shields/rl/15_train_eval_tune.py
 - 
					209examples/shields/rl/rllibutils.py
 - 
					68examples/shields/rl/sb3utils.py
 - 
					124examples/shields/rl/utils.py
 
@ -0,0 +1,209 @@ | 
				
			|||
import gymnasium as gym | 
				
			|||
import numpy as np | 
				
			|||
import random | 
				
			|||
 | 
				
			|||
from minigrid.core.actions import Actions | 
				
			|||
from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX | 
				
			|||
 | 
				
			|||
from gymnasium.spaces import Dict, Box | 
				
			|||
from collections import deque | 
				
			|||
from ray.rllib.utils.numpy import one_hot | 
				
			|||
 | 
				
			|||
from helpers import get_action_index_mapping | 
				
			|||
from shieldhandlers import ShieldHandler | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class OneHotShieldingWrapper(gym.core.ObservationWrapper): | 
				
			|||
    def __init__(self, env, vector_index, framestack): | 
				
			|||
        super().__init__(env) | 
				
			|||
        self.framestack = framestack | 
				
			|||
        # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types. | 
				
			|||
        # +4: Direction. | 
				
			|||
        self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4 | 
				
			|||
        self.init_x = None | 
				
			|||
        self.init_y = None | 
				
			|||
        self.x_positions = [] | 
				
			|||
        self.y_positions = [] | 
				
			|||
        self.x_y_delta_buffer = deque(maxlen=100) | 
				
			|||
        self.vector_index = vector_index | 
				
			|||
        self.frame_buffer = deque(maxlen=self.framestack) | 
				
			|||
        for _ in range(self.framestack): | 
				
			|||
            self.frame_buffer.append(np.zeros((self.single_frame_dim,))) | 
				
			|||
 | 
				
			|||
        self.observation_space = Dict( | 
				
			|||
            { | 
				
			|||
                "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), | 
				
			|||
                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), | 
				
			|||
            } | 
				
			|||
            ) | 
				
			|||
 | 
				
			|||
    def observation(self, obs): | 
				
			|||
        # Debug output: max-x/y positions to watch exploration progress. | 
				
			|||
        # print(F"Initial observation in Wrapper {obs}") | 
				
			|||
        if self.step_count == 0: | 
				
			|||
            for _ in range(self.framestack): | 
				
			|||
                self.frame_buffer.append(np.zeros((self.single_frame_dim,))) | 
				
			|||
            if self.vector_index == 0: | 
				
			|||
                if self.x_positions: | 
				
			|||
                    max_diff = max( | 
				
			|||
                        np.sqrt( | 
				
			|||
                            (np.array(self.x_positions) - self.init_x) ** 2 | 
				
			|||
                            + (np.array(self.y_positions) - self.init_y) ** 2 | 
				
			|||
                        ) | 
				
			|||
                    ) | 
				
			|||
                    self.x_y_delta_buffer.append(max_diff) | 
				
			|||
                    print( | 
				
			|||
                        "100-average dist travelled={}".format( | 
				
			|||
                            np.mean(self.x_y_delta_buffer) | 
				
			|||
                        ) | 
				
			|||
                    ) | 
				
			|||
                    self.x_positions = [] | 
				
			|||
                    self.y_positions = [] | 
				
			|||
                self.init_x = self.agent_pos[0] | 
				
			|||
                self.init_y = self.agent_pos[1] | 
				
			|||
 | 
				
			|||
 | 
				
			|||
        self.x_positions.append(self.agent_pos[0]) | 
				
			|||
        self.y_positions.append(self.agent_pos[1]) | 
				
			|||
 | 
				
			|||
        image = obs["data"] | 
				
			|||
        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten. | 
				
			|||
        objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX)) | 
				
			|||
        colors = one_hot(image[:, :, 1], depth=len(COLORS)) | 
				
			|||
        states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX)) | 
				
			|||
 | 
				
			|||
        all_ = np.concatenate([objects, colors, states], -1) | 
				
			|||
        all_flat = np.reshape(all_, (-1,)) | 
				
			|||
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) | 
				
			|||
        single_frame = np.concatenate([all_flat, direction]) | 
				
			|||
        self.frame_buffer.append(single_frame) | 
				
			|||
 | 
				
			|||
        tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } | 
				
			|||
        return tmp | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class MiniGridShieldingWrapper(gym.core.Wrapper): | 
				
			|||
    def __init__(self,  | 
				
			|||
                 env,  | 
				
			|||
                shield_creator : ShieldHandler, | 
				
			|||
                shield_query_creator, | 
				
			|||
                create_shield_at_reset=True,     | 
				
			|||
                mask_actions=True): | 
				
			|||
        super(MiniGridShieldingWrapper, self).__init__(env) | 
				
			|||
        self.max_available_actions = env.action_space.n | 
				
			|||
        self.observation_space = Dict( | 
				
			|||
            { | 
				
			|||
                "data": env.observation_space.spaces["image"], | 
				
			|||
                "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), | 
				
			|||
            } | 
				
			|||
        ) | 
				
			|||
        self.shield_creator = shield_creator | 
				
			|||
        self.create_shield_at_reset = create_shield_at_reset | 
				
			|||
        self.shield = shield_creator.create_shield(env=self.env) | 
				
			|||
        self.mask_actions = mask_actions | 
				
			|||
        self.shield_query_creator = shield_query_creator | 
				
			|||
        print(F"Shielding is {self.mask_actions}") | 
				
			|||
 | 
				
			|||
    def create_action_mask(self): | 
				
			|||
        if not self.mask_actions: | 
				
			|||
            ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) | 
				
			|||
            return ret | 
				
			|||
         | 
				
			|||
        cur_pos_str = self.shield_query_creator(self.env) | 
				
			|||
         | 
				
			|||
        # Create the mask | 
				
			|||
        # If shield restricts action mask only valid with 1.0 | 
				
			|||
        # else set all actions as valid | 
				
			|||
        allowed_actions = [] | 
				
			|||
        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) | 
				
			|||
 | 
				
			|||
        if cur_pos_str in self.shield and self.shield[cur_pos_str]: | 
				
			|||
            allowed_actions = self.shield[cur_pos_str] | 
				
			|||
            zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8) | 
				
			|||
            has_allowed_actions = False | 
				
			|||
 | 
				
			|||
            for allowed_action in allowed_actions: | 
				
			|||
                index =  get_action_index_mapping(allowed_action.labels) # Allowed_action is a set | 
				
			|||
                if index is None:                | 
				
			|||
                    assert(False) | 
				
			|||
                 | 
				
			|||
                allowed =  1.0  | 
				
			|||
                has_allowed_actions = True | 
				
			|||
                mask[index] = allowed                | 
				
			|||
        else: | 
				
			|||
            for index, x in enumerate(mask): | 
				
			|||
                mask[index] = 1.0 | 
				
			|||
         | 
				
			|||
        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) | 
				
			|||
 | 
				
			|||
        if front_tile is not None and front_tile.type == "key": | 
				
			|||
            mask[Actions.pickup] = 1.0 | 
				
			|||
             | 
				
			|||
             | 
				
			|||
        if front_tile and front_tile.type == "door": | 
				
			|||
            mask[Actions.toggle] = 1.0 | 
				
			|||
        # print(F"Mask is {mask} State: {cur_pos_str}") | 
				
			|||
        return mask | 
				
			|||
 | 
				
			|||
    def reset(self, *, seed=None, options=None): | 
				
			|||
        obs, infos = self.env.reset(seed=seed, options=options) | 
				
			|||
         | 
				
			|||
        if self.create_shield_at_reset and self.mask_actions: | 
				
			|||
            self.shield = self.shield_creator.create_shield(env=self.env) | 
				
			|||
         | 
				
			|||
        mask = self.create_action_mask() | 
				
			|||
        return { | 
				
			|||
            "data": obs["image"], | 
				
			|||
            "action_mask": mask | 
				
			|||
        }, infos | 
				
			|||
 | 
				
			|||
    def step(self, action): | 
				
			|||
        orig_obs, rew, done, truncated, info = self.env.step(action) | 
				
			|||
 | 
				
			|||
        mask = self.create_action_mask() | 
				
			|||
        obs = { | 
				
			|||
            "data": orig_obs["image"], | 
				
			|||
            "action_mask": mask, | 
				
			|||
        } | 
				
			|||
 | 
				
			|||
        return obs, rew, done, truncated, info | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def shielding_env_creater(config): | 
				
			|||
    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") | 
				
			|||
    framestack = config.get("framestack", 4) | 
				
			|||
    args = config.get("args", None) | 
				
			|||
    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" | 
				
			|||
    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism" | 
				
			|||
    shielding = config.get("shielding", False) | 
				
			|||
    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path, | 
				
			|||
                                           grid_to_prism_path=args.grid_to_prism_binary_path, | 
				
			|||
                                           prism_path=args.prism_path, | 
				
			|||
                                           formula=args.formula, | 
				
			|||
                                           shield_value=args.shield_value, | 
				
			|||
                                           prism_config=args.prism_config, | 
				
			|||
                                           shield_comparision=args.shield_comparision) | 
				
			|||
 | 
				
			|||
    probability_intended = args.probability_intended | 
				
			|||
    probability_displacement = args.probability_displacement | 
				
			|||
 | 
				
			|||
    env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement) | 
				
			|||
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) | 
				
			|||
 | 
				
			|||
    env = OneHotShieldingWrapper(env, | 
				
			|||
                        config.vector_index if hasattr(config, "vector_index") else 0, | 
				
			|||
                        framestack=framestack | 
				
			|||
                        ) | 
				
			|||
 | 
				
			|||
 | 
				
			|||
    return env | 
				
			|||
 | 
				
			|||
   | 
				
			|||
def register_minigrid_shielding_env(args): | 
				
			|||
    env_name = "mini-grid-shielding" | 
				
			|||
    register_env(env_name, shielding_env_creater) | 
				
			|||
 | 
				
			|||
    ModelCatalog.register_custom_model( | 
				
			|||
        "shielding_model", | 
				
			|||
        TorchActionMaskModel | 
				
			|||
    ) | 
				
			|||
@ -0,0 +1,68 @@ | 
				
			|||
import gymnasium as gym | 
				
			|||
import numpy as np | 
				
			|||
import random | 
				
			|||
 | 
				
			|||
class MiniGridSbShieldingWrapper(gym.core.Wrapper): | 
				
			|||
    def __init__(self,  | 
				
			|||
                 env,  | 
				
			|||
                 shield_creator : ShieldHandler, | 
				
			|||
                 shield_query_creator, | 
				
			|||
                 create_shield_at_reset = True, | 
				
			|||
                 mask_actions=True, | 
				
			|||
                 ): | 
				
			|||
        super(MiniGridSbShieldingWrapper, self).__init__(env) | 
				
			|||
        self.max_available_actions = env.action_space.n | 
				
			|||
        self.observation_space = env.observation_space.spaces["image"] | 
				
			|||
         | 
				
			|||
        self.shield_creator = shield_creator | 
				
			|||
        self.mask_actions = mask_actions | 
				
			|||
        self.shield_query_creator = shield_query_creator | 
				
			|||
 | 
				
			|||
    def create_action_mask(self): | 
				
			|||
        if not self.mask_actions: | 
				
			|||
            return  np.array([1.0] * self.max_available_actions, dtype=np.int8) | 
				
			|||
                | 
				
			|||
        cur_pos_str = self.shield_query_creator(self.env) | 
				
			|||
         | 
				
			|||
        allowed_actions = [] | 
				
			|||
 | 
				
			|||
        # Create the mask | 
				
			|||
        # If shield restricts actions, mask only valid actions with 1.0 | 
				
			|||
        # else set all actions valid | 
				
			|||
        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) | 
				
			|||
 | 
				
			|||
        if cur_pos_str in self.shield and self.shield[cur_pos_str]: | 
				
			|||
            allowed_actions = self.shield[cur_pos_str] | 
				
			|||
            for allowed_action in allowed_actions: | 
				
			|||
                index =  get_action_index_mapping(allowed_action.labels) | 
				
			|||
                if index is None: | 
				
			|||
                     assert(False) | 
				
			|||
                               | 
				
			|||
                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] | 
				
			|||
        else: | 
				
			|||
            for index, x in enumerate(mask): | 
				
			|||
                mask[index] = 1.0 | 
				
			|||
         | 
				
			|||
        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) | 
				
			|||
 | 
				
			|||
             | 
				
			|||
        if front_tile and front_tile.type == "door": | 
				
			|||
            mask[Actions.toggle] = 1.0             | 
				
			|||
             | 
				
			|||
        return mask   | 
				
			|||
     | 
				
			|||
 | 
				
			|||
    def reset(self, *, seed=None, options=None): | 
				
			|||
        obs, infos = self.env.reset(seed=seed, options=options) | 
				
			|||
       | 
				
			|||
        shield = self.shield_creator.create_shield(env=self.env) | 
				
			|||
         | 
				
			|||
        self.shield = shield | 
				
			|||
        return obs["image"], infos | 
				
			|||
 | 
				
			|||
    def step(self, action): | 
				
			|||
        orig_obs, rew, done, truncated, info = self.env.step(action) | 
				
			|||
        obs = orig_obs["image"] | 
				
			|||
         | 
				
			|||
        return obs, rew, done, truncated, info | 
				
			|||
 | 
				
			|||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue