diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index 7e5660d..2d7aa02 100755
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -9,48 +9,12 @@ from ray.rllib.models import ModelCatalog
 
 
 from torch_action_mask_model import TorchActionMaskModel
-from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir, ShieldingConfig
-from shieldhandlers import MiniGridShieldHandler, create_shield_query
-from callbacks import MyCallbacks
+from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
+from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
+from callbacks import CustomCallback
 
 from ray.tune.logger import TBXLogger   
 
-def shielding_env_creater(config):
-    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
-    framestack = config.get("framestack", 4)
-    args = config.get("args", None)
-    args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
-    args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"
-    
-    prob_forward = args.prob_forward
-    prob_direct = args.prob_direct
-    prob_next = args.prob_next
-
-    shield_creator = MiniGridShieldHandler(args.grid_path, 
-                                            args.grid_to_prism_binary_path,
-                                            args.prism_path, 
-                                            args.formula,
-                                            args.shield_value,
-                                            args.prism_config,
-                                            shield_comparision=args.shield_comparision)
-
-    env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
-    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, 
-                                   shield_query_creator=create_shield_query,
-                                   mask_actions=args.shielding != ShieldingConfig.Disabled,
-                                   create_shield_at_reset=args.shield_creation_at_reset)
-    # env = minigrid.wrappers.ImgObsWrapper(env)
-    # env = ImgObsWrapper(env)
-    env = OneHotShieldingWrapper(env,
-                        config.vector_index if hasattr(config, "vector_index") else 0,
-                        framestack=framestack
-                        )
-    
-    
-    return env
-
-
 
 def register_minigrid_shielding_env(args):
     env_name = "mini-grid-shielding"
@@ -71,7 +35,7 @@ def ppo(args):
         .resources(num_gpus=0)
         .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        .callbacks(MyCallbacks)
+        .callbacks(CustomCallback)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": TBXLogger, 
@@ -109,7 +73,7 @@ def dqn(args):
     config = config.rollouts(num_rollout_workers=args.workers)
     config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
-    config = config.callbacks(MyCallbacks)
+    config = config.callbacks(CustomCallback)
     config = config.rl_module(_enable_rl_module_api = False)
     config = config.debugging(logger_config={
             "type": TBXLogger, 
diff --git a/examples/shields/rl/12_minigridrl_tune.py b/examples/shields/rl/12_minigridrl_tune.py
index e0ff945..e64d609 100644
--- a/examples/shields/rl/12_minigridrl_tune.py
+++ b/examples/shields/rl/12_minigridrl_tune.py
@@ -11,36 +11,13 @@ from ray.rllib.models import ModelCatalog
 
 
 from torch_action_mask_model import TorchActionMaskModel
-from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir, ShieldingConfig
-from shieldhandlers import MiniGridShieldHandler, create_shield_query
-from callbacks import MyCallbacks
+from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
+from callbacks import CustomCallback
 
 from torch.utils.tensorboard import SummaryWriter
 from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger
 
-def shielding_env_creater(config):
-    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
-    framestack = config.get("framestack", 4)
-    args = config.get("args", None)
-    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
-    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
-    
-    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
-    
-    env = gym.make(name)
-    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
-    # env = minigrid.wrappers.ImgObsWrapper(env)
-    # env = ImgObsWrapper(env)
-    env = OneHotShieldingWrapper(env,
-                        config.vector_index if hasattr(config, "vector_index") else 0,
-                        framestack=framestack
-                        )
-    
-    
-    return env
-
-
 
 def register_minigrid_shielding_env(args):
     env_name = "mini-grid-shielding"
@@ -60,7 +37,7 @@ def ppo(args):
         .resources(num_gpus=0)
         .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        .callbacks(MyCallbacks)
+        .callbacks(CustomCallback)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": TBXLogger, 
@@ -83,7 +60,7 @@ def dqn(args):
     config = config.rollouts(num_rollout_workers=args.workers)
     config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
-    config = config.callbacks(MyCallbacks)
+    config = config.callbacks(CustomCallback)
     config = config.rl_module(_enable_rl_module_api = False)
     config = config.debugging(logger_config={
             "type": TBXLogger, 
diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py
index 43575dc..f04c65e 100644
--- a/examples/shields/rl/13_minigridsb.py
+++ b/examples/shields/rl/13_minigridsb.py
@@ -2,7 +2,6 @@ from sb3_contrib import MaskablePPO
 from sb3_contrib.common.maskable.evaluation import evaluate_policy
 from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
 from sb3_contrib.common.wrappers import ActionMasker
-from stable_baselines3.common.callbacks import BaseCallback
 
 import gymnasium as gym
 
@@ -10,28 +9,13 @@ from minigrid.core.actions import Actions
 
 import time
 
-from helpers import parse_arguments, create_log_dir, ShieldingConfig
-from shieldhandlers import MiniGridShieldHandler, create_shield_query
-from wrappers import MiniGridSbShieldingWrapper
-
-class CustomCallback(BaseCallback):
-    def __init__(self, verbose: int = 0, env=None):
-        super(CustomCallback, self).__init__(verbose)
-        self.env = env
-        
-        
-    def _on_step(self) -> bool:
-        print(self.env.printGrid())
-        return super()._on_step()
-
-
-
+from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
+from sb3utils import MiniGridSbShieldingWrapper
 
 def mask_fn(env: gym.Env):
     return env.create_action_mask()
     
 
-
 def main():
     import argparse
     args = parse_arguments(argparse)
@@ -44,13 +28,12 @@ def main():
     env = gym.make(args.env, render_mode="rgb_array")
     env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
     env = ActionMasker(env, mask_fn)
-    callback = CustomCallback(1, env)
     model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
     
     steps = args.steps
     
     
-    model.learn(steps, callback=callback)
+    model.learn(steps)
  
   #W  mean_reward, std_reward = evaluate_policy(model, model.get_env())
     
diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py
index dae1c77..56913a6 100644
--- a/examples/shields/rl/14_train_eval.py
+++ b/examples/shields/rl/14_train_eval.py
@@ -8,39 +8,13 @@ from ray.rllib.models import ModelCatalog
 
 
 from torch_action_mask_model import TorchActionMaskModel
-from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir, ShieldingConfig
-from shieldhandlers import MiniGridShieldHandler, create_shield_query
+from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
+from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
 
-from callbacks import MyCallbacks
+from callbacks import CustomCallback
 
 from torch.utils.tensorboard import SummaryWriter
 
-
-  
-
-def shielding_env_creater(config):
-    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
-    framestack = config.get("framestack", 4)
-    args = config.get("args", None)
-    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
-    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
-    
-    shielding = config.get("shielding", False)
-    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
-    
-    env = gym.make(name)
-    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
-
-    env = OneHotShieldingWrapper(env,
-                        config.vector_index if hasattr(config, "vector_index") else 0,
-                        framestack=framestack
-                        )
-    
-    
-    return env
-
-
 def register_minigrid_shielding_env(args):
     env_name = "mini-grid-shielding"
     register_env(env_name, shielding_env_creater)
@@ -60,7 +34,7 @@ def ppo(args):
         .environment( env="mini-grid-shielding",
                       env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        .callbacks(MyCallbacks)
+        .callbacks(CustomCallback)
         .evaluation(evaluation_config={ 
                                        "evaluation_interval": 1,
                                         "evaluation_duration": 10,
diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index 9220af5..ad9d5c1 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -14,43 +14,11 @@ from ray.rllib.algorithms.callbacks import make_multi_callbacks
 from ray.air import session
 
 from torch_action_mask_model import TorchActionMaskModel
-from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
-from shieldhandlers import MiniGridShieldHandler, create_shield_query
+from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
+from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig, test_name
 
 from torch.utils.tensorboard import SummaryWriter
-from callbacks import MyCallbacks
-
-
-def shielding_env_creater(config):
-    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
-    framestack = config.get("framestack", 4)
-    args = config.get("args", None)
-    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
-    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
-    shielding = config.get("shielding", False)
-    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
-                                           grid_to_prism_path=args.grid_to_prism_binary_path,
-                                           prism_path=args.prism_path,
-                                           formula=args.formula,
-                                           shield_value=args.shield_value,
-                                           prism_config=args.prism_config,
-                                           shield_comparision=args.shield_comparision)
-
-    prob_forward = args.prob_forward
-    prob_direct = args.prob_direct
-    prob_next = args.prob_next
-
-    env = gym.make(name, randomize_start=True,probability_forward=prob_forward, probability_direct_neighbour=prob_direct, probability_next_neighbour=prob_next)
-    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
-
-    env = OneHotShieldingWrapper(env,
-                        config.vector_index if hasattr(config, "vector_index") else 0,
-                        framestack=framestack
-                        )
-
-
-    return env
+from callbacks import CustomCallback
 
 
 def register_minigrid_shielding_env(args):
@@ -79,7 +47,7 @@ def ppo(args):
                                   "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
                                   },)
         .framework("torch")
-        .callbacks(MyCallbacks)
+        .callbacks(CustomCallback)
         .evaluation(evaluation_config={
                                        "evaluation_interval": 1,
                                         "evaluation_duration": 10,
@@ -133,31 +101,6 @@ def ppo(args):
 ]
     pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
 
-   # algo = Algorithm.from_checkpoint(best_result.checkpoint)
-
-
-    # eval_log_dir = F"{logdir}-eval"
-
-    # writer = SummaryWriter(log_dir=eval_log_dir)
-    # csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
-
-
-    # for i in range(args.evaluations):
-    #     eval_result = algo.evaluate()
-    #     print(pretty_print(eval_result))
-    #     print(eval_result)
-    #     # logger.on_result(eval_result)
-
-    #     csv_logger.on_result(eval_result)
-
-    #     evaluation = eval_result['evaluation']
-    #     epsiode_reward_mean = evaluation['episode_reward_mean']
-    #     episode_len_mean = evaluation['episode_len_mean']
-    #     print(epsiode_reward_mean)
-    #     writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
-    #     writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
-
-
 def main():
     ray.init(num_cpus=3)
     import argparse
diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py
new file mode 100644
index 0000000..03b8253
--- /dev/null
+++ b/examples/shields/rl/rllibutils.py
@@ -0,0 +1,209 @@
+import gymnasium as gym
+import numpy as np
+import random
+
+from minigrid.core.actions import Actions
+from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
+
+from gymnasium.spaces import Dict, Box
+from collections import deque
+from ray.rllib.utils.numpy import one_hot
+
+from helpers import get_action_index_mapping
+from shieldhandlers import ShieldHandler
+
+
+class OneHotShieldingWrapper(gym.core.ObservationWrapper):
+    def __init__(self, env, vector_index, framestack):
+        super().__init__(env)
+        self.framestack = framestack
+        # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
+        # +4: Direction.
+        self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
+        self.init_x = None
+        self.init_y = None
+        self.x_positions = []
+        self.y_positions = []
+        self.x_y_delta_buffer = deque(maxlen=100)
+        self.vector_index = vector_index
+        self.frame_buffer = deque(maxlen=self.framestack)
+        for _ in range(self.framestack):
+            self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
+
+        self.observation_space = Dict(
+            {
+                "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
+                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
+            }
+            )
+
+    def observation(self, obs):
+        # Debug output: max-x/y positions to watch exploration progress.
+        # print(F"Initial observation in Wrapper {obs}")
+        if self.step_count == 0:
+            for _ in range(self.framestack):
+                self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
+            if self.vector_index == 0:
+                if self.x_positions:
+                    max_diff = max(
+                        np.sqrt(
+                            (np.array(self.x_positions) - self.init_x) ** 2
+                            + (np.array(self.y_positions) - self.init_y) ** 2
+                        )
+                    )
+                    self.x_y_delta_buffer.append(max_diff)
+                    print(
+                        "100-average dist travelled={}".format(
+                            np.mean(self.x_y_delta_buffer)
+                        )
+                    )
+                    self.x_positions = []
+                    self.y_positions = []
+                self.init_x = self.agent_pos[0]
+                self.init_y = self.agent_pos[1]
+
+
+        self.x_positions.append(self.agent_pos[0])
+        self.y_positions.append(self.agent_pos[1])
+
+        image = obs["data"]
+        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
+        objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
+        colors = one_hot(image[:, :, 1], depth=len(COLORS))
+        states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
+
+        all_ = np.concatenate([objects, colors, states], -1)
+        all_flat = np.reshape(all_, (-1,))
+        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
+        single_frame = np.concatenate([all_flat, direction])
+        self.frame_buffer.append(single_frame)
+
+        tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
+        return tmp
+
+
+class MiniGridShieldingWrapper(gym.core.Wrapper):
+    def __init__(self, 
+                 env, 
+                shield_creator : ShieldHandler,
+                shield_query_creator,
+                create_shield_at_reset=True,    
+                mask_actions=True):
+        super(MiniGridShieldingWrapper, self).__init__(env)
+        self.max_available_actions = env.action_space.n
+        self.observation_space = Dict(
+            {
+                "data": env.observation_space.spaces["image"],
+                "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
+            }
+        )
+        self.shield_creator = shield_creator
+        self.create_shield_at_reset = create_shield_at_reset
+        self.shield = shield_creator.create_shield(env=self.env)
+        self.mask_actions = mask_actions
+        self.shield_query_creator = shield_query_creator
+        print(F"Shielding is {self.mask_actions}")
+
+    def create_action_mask(self):
+        if not self.mask_actions:
+            ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
+            return ret
+        
+        cur_pos_str = self.shield_query_creator(self.env)
+        
+        # Create the mask
+        # If shield restricts action mask only valid with 1.0
+        # else set all actions as valid
+        allowed_actions = []
+        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
+
+        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
+            allowed_actions = self.shield[cur_pos_str]
+            zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
+            has_allowed_actions = False
+
+            for allowed_action in allowed_actions:
+                index =  get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
+                if index is None:               
+                    assert(False)
+                
+                allowed =  1.0 
+                has_allowed_actions = True
+                mask[index] = allowed               
+        else:
+            for index, x in enumerate(mask):
+                mask[index] = 1.0
+        
+        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
+
+        if front_tile is not None and front_tile.type == "key":
+            mask[Actions.pickup] = 1.0
+            
+            
+        if front_tile and front_tile.type == "door":
+            mask[Actions.toggle] = 1.0
+        # print(F"Mask is {mask} State: {cur_pos_str}")
+        return mask
+
+    def reset(self, *, seed=None, options=None):
+        obs, infos = self.env.reset(seed=seed, options=options)
+        
+        if self.create_shield_at_reset and self.mask_actions:
+            self.shield = self.shield_creator.create_shield(env=self.env)
+        
+        mask = self.create_action_mask()
+        return {
+            "data": obs["image"],
+            "action_mask": mask
+        }, infos
+
+    def step(self, action):
+        orig_obs, rew, done, truncated, info = self.env.step(action)
+
+        mask = self.create_action_mask()
+        obs = {
+            "data": orig_obs["image"],
+            "action_mask": mask,
+        }
+
+        return obs, rew, done, truncated, info
+
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
+    shielding = config.get("shielding", False)
+    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
+                                           grid_to_prism_path=args.grid_to_prism_binary_path,
+                                           prism_path=args.prism_path,
+                                           formula=args.formula,
+                                           shield_value=args.shield_value,
+                                           prism_config=args.prism_config,
+                                           shield_comparision=args.shield_comparision)
+
+    probability_intended = args.probability_intended
+    probability_displacement = args.probability_displacement
+
+    env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
+
+    env = OneHotShieldingWrapper(env,
+                        config.vector_index if hasattr(config, "vector_index") else 0,
+                        framestack=framestack
+                        )
+
+
+    return env
+
+  
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+
+    ModelCatalog.register_custom_model(
+        "shielding_model",
+        TorchActionMaskModel
+    )
diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py
new file mode 100644
index 0000000..f0798d2
--- /dev/null
+++ b/examples/shields/rl/sb3utils.py
@@ -0,0 +1,68 @@
+import gymnasium as gym
+import numpy as np
+import random
+
+class MiniGridSbShieldingWrapper(gym.core.Wrapper):
+    def __init__(self, 
+                 env, 
+                 shield_creator : ShieldHandler,
+                 shield_query_creator,
+                 create_shield_at_reset = True,
+                 mask_actions=True,
+                 ):
+        super(MiniGridSbShieldingWrapper, self).__init__(env)
+        self.max_available_actions = env.action_space.n
+        self.observation_space = env.observation_space.spaces["image"]
+        
+        self.shield_creator = shield_creator
+        self.mask_actions = mask_actions
+        self.shield_query_creator = shield_query_creator
+
+    def create_action_mask(self):
+        if not self.mask_actions:
+            return  np.array([1.0] * self.max_available_actions, dtype=np.int8)
+               
+        cur_pos_str = self.shield_query_creator(self.env)
+        
+        allowed_actions = []
+
+        # Create the mask
+        # If shield restricts actions, mask only valid actions with 1.0
+        # else set all actions valid
+        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
+
+        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
+            allowed_actions = self.shield[cur_pos_str]
+            for allowed_action in allowed_actions:
+                index =  get_action_index_mapping(allowed_action.labels)
+                if index is None:
+                     assert(False)
+                              
+                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
+        else:
+            for index, x in enumerate(mask):
+                mask[index] = 1.0
+        
+        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
+
+            
+        if front_tile and front_tile.type == "door":
+            mask[Actions.toggle] = 1.0            
+            
+        return mask  
+    
+
+    def reset(self, *, seed=None, options=None):
+        obs, infos = self.env.reset(seed=seed, options=options)
+      
+        shield = self.shield_creator.create_shield(env=self.env)
+        
+        self.shield = shield
+        return obs["image"], infos
+
+    def step(self, action):
+        orig_obs, rew, done, truncated, info = self.env.step(action)
+        obs = orig_obs["image"]
+        
+        return obs, rew, done, truncated, info
+
diff --git a/examples/shields/rl/shieldhandlers.py b/examples/shields/rl/utils.py
similarity index 59%
rename from examples/shields/rl/shieldhandlers.py
rename to examples/shields/rl/utils.py
index 95cab72..cab65be 100644
--- a/examples/shields/rl/shieldhandlers.py
+++ b/examples/shields/rl/utils.py
@@ -78,7 +78,6 @@ class MiniGridShieldHandler(ShieldHandler):
         
         assert result.has_shield
         shield = result.shield
-        stormpy.shields.export_shield(model, shield, "Grid.shield")
         action_dictionary = {}
         shield_scheduler = shield.construct()
         state_valuations = model.state_valuations
@@ -193,4 +192,125 @@ def create_shield_query(env):
     query = f"[{agent_carrying}& {''.join(agent_key_status)}!AgentDone\t{adv_status_text}{move_text}{door_status_text}{agent_position}{adv_positions_text}{key_positions_text}]"
 
     return query
-    
\ No newline at end of file
+  
+
+class ShieldingConfig(Enum):
+    Training = 'training'
+    Evaluation = 'evaluation'
+    Disabled = 'none'
+    Full = 'full'
+    
+    def __str__(self) -> str:
+        return self.value
+
+
+def extract_keys(env):
+    keys = []
+    for j in range(env.grid.height):
+        for i in range(env.grid.width):
+            obj = env.grid.get(i,j)
+            
+            if obj and obj.type == "key":
+                keys.append((obj, i, j))
+    
+    if env.carrying and env.carrying.type == "key":
+        keys.append((env.carrying, -1, -1))
+    # TODO Maybe need to add ordering of keys so it matches the order in the shield
+    return keys
+
+def extract_doors(env):
+    doors = []
+    for j in range(env.grid.height):
+        for i in range(env.grid.width):
+            obj = env.grid.get(i,j)
+            
+            if obj and obj.type == "door":
+                doors.append(obj)
+                
+    return doors
+
+def extract_adversaries(env):
+    adv = []
+    
+    if not hasattr(env, "adversaries"):
+        return []
+    
+    for color, adversary in env.adversaries.items():
+        adv.append(adversary)
+    
+    
+    return adv
+
+def create_log_dir(args):
+    return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparision}-env:{args.env}-conf:{args.prism_config}"
+
+def test_name(args):
+    return F"{args.expname}"
+
+def get_action_index_mapping(actions):
+    for action_str in actions:
+        if not "Agent" in action_str:
+            continue
+        
+        if "move" in action_str:
+            return Actions.forward
+        elif "left" in action_str:
+            return Actions.left
+        elif "right" in action_str:
+            return Actions.right
+        elif "pickup" in action_str:
+            return Actions.pickup
+        elif "done" in action_str:
+            return Actions.done    
+        elif "drop" in action_str:
+            return Actions.drop
+        elif "toggle" in action_str:
+            return Actions.toggle
+        elif "unlock" in action_str:
+            return Actions.toggle
+    
+    raise ValueError("No action mapping found")
+    
+
+def parse_arguments(argparse):
+    parser = argparse.ArgumentParser()
+    # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
+    parser.add_argument("--env", 
+                        help="gym environment to load", 
+                        default="MiniGrid-LavaSlipperyS12-v2", 
+                        choices=[
+                                "MiniGrid-Adv-8x8-v0",
+                                "MiniGrid-AdvSimple-8x8-v0",
+                                "MiniGrid-LavaCrossingS9N1-v0",
+                                "MiniGrid-LavaCrossingS9N3-v0",
+                                "MiniGrid-LavaSlipperyS12-v0",
+                                "MiniGrid-LavaSlipperyS12-v1",
+                                "MiniGrid-LavaSlipperyS12-v2",
+                                "MiniGrid-LavaSlipperyS12-v3",
+                             
+                                ])
+    
+   # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
+    parser.add_argument("--grid_to_prism_binary_path", default="./main")
+    parser.add_argument("--grid_path", default="grid")
+    parser.add_argument("--prism_path", default="grid")
+    parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
+    parser.add_argument("--log_dir", default="../log_results/")
+    parser.add_argument("--evaluations", type=int, default=30 )
+    parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
+    # parser.add_argument("--formula", default="<<Agent>> Pmax=? [G <= 4 !\"AgentRanIntoAdversary\"]")
+    parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--num_gpus", type=float, default=0)
+    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
+    parser.add_argument("--steps", default=20_000, type=int)
+    parser.add_argument("--expname", default="exp")
+    parser.add_argument("--shield_creation_at_reset", action=argparse.BooleanOptionalAction)
+    parser.add_argument("--prism_config",  default=None)
+    parser.add_argument("--shield_value", default=0.9, type=float)
+    parser.add_argument("--probability_displacement", default=1/4, type=float)
+    parser.add_argument("--probability_intended", default=3/4, type=float)
+    parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
+    # parser.add_argument("--random_starts", default=1, type=int)
+    args = parser.parse_args()
+    
+    return args