diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index 4ce06a7..9f25be6 100644
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -1,10 +1,10 @@
-# from typing import Dict
-# from ray.rllib.env.base_env import BaseEnv
-# from ray.rllib.evaluation import RolloutWorker
-# from ray.rllib.evaluation.episode import Episode
-# from ray.rllib.evaluation.episode_v2 import EpisodeV2
-# from ray.rllib.policy import Policy
-# from ray.rllib.utils.typing import PolicyID
+from typing import Dict
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.evaluation import RolloutWorker
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.episode_v2 import EpisodeV2
+from ray.rllib.policy import Policy
+from ray.rllib.utils.typing import PolicyID
 
 import gymnasium as gym
 
@@ -15,47 +15,47 @@ import minigrid
 from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.dqn.dqn import DQNConfig
-# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.tune.logger import pretty_print
 from ray.rllib.models import ModelCatalog
 
 
 from TorchActionMaskModel import TorchActionMaskModel
 from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
 from ShieldHandlers import MiniGridShieldHandler
 
 import matplotlib.pyplot as plt
 
 from ray.tune.logger import TBXLogger   
 
-# class MyCallbacks(DefaultCallbacks):
-#     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         episode.user_data["count"] = 0
-#         # print("On episode start print")
-#         # print(env.printGrid())
-#         # print(worker)
-#         # print(env.action_space.n)
-#         # print(env.actions)
-#         # print(env.mission)
-#         # print(env.observation_space)
-#         # img = env.get_frame()
-#         # plt.imshow(img)
-#         # plt.show()
+class MyCallbacks(DefaultCallbacks):
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        episode.user_data["count"] = 0
+        # print("On episode start print")
+        # print(env.printGrid())
+        # print(worker)
+        # print(env.action_space.n)
+        # print(env.actions)
+        # print(env.mission)
+        # print(env.observation_space)
+        # img = env.get_frame()
+        # plt.imshow(img)
+        # plt.show()
     
        
-#     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#          episode.user_data["count"] = episode.user_data["count"] + 1
-#          env = base_env.get_sub_environments()[0]
-#         # print(env.printGrid())
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+         episode.user_data["count"] = episode.user_data["count"] + 1
+         env = base_env.get_sub_environments()[0]
+        # print(env.printGrid())
     
-#     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         #print("On episode end print")
-#         #print(env.printGrid())
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        #print("On episode end print")
+        #print(env.printGrid())
         
                     
 
@@ -83,7 +83,7 @@ def shielding_env_creater(config):
 
 
 def register_minigrid_shielding_env(args):
-    env_name = "mini-grid"
+    env_name = "mini-grid-shielding"
     register_env(env_name, shielding_env_creater)
 
     ModelCatalog.register_custom_model(
@@ -98,25 +98,21 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"name": args.env, "args": args})
+        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        #.callbacks(MyCallbacks)
+        .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": TBXLogger, 
             "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model",
-            "custom_model_config" : {"no_masking": args.no_masking}            
+            "custom_model": "shielding_model"
         }))
     
-    algo =(
-        
+    algo = (
         config.build()
-    )
-    
-    algo.eva
+    )
     
     for i in range(args.iterations):
         result = algo.train()
@@ -134,7 +130,7 @@ def dqn(args):
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=args.workers)
-    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
+    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
     #config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
@@ -143,8 +139,7 @@ def dqn(args):
             "logdir": create_log_dir(args)
         })
     config = config.training(hiddens=[], dueling=False, model={    
-            "custom_model": "shielding_model",
-            "custom_model_config" : {"no_masking": args.no_masking}
+            "custom_model": "shielding_model"
     })
     
     algo = (
diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py
new file mode 100644
index 0000000..047873e
--- /dev/null
+++ b/examples/shields/rl/14_train_eval.py
@@ -0,0 +1,100 @@
+import gymnasium as gym
+
+import minigrid
+
+from ray.tune import register_env
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.tune.logger import pretty_print
+from ray.rllib.models import ModelCatalog
+
+
+from TorchActionMaskModel import TorchActionMaskModel
+from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from ShieldHandlers import MiniGridShieldHandler
+
+from ray.tune.logger import TBXLogger
+
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
+    
+    shielding = config.get("shielding", False)
+
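+    # The shield handler exports the worker's grid, translates it to a PRISM model, and derives a shield for the configured formula.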
+    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
+    
+    env = gym.make(name)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding)
+
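+    # One-hot encode the observations (with framestacking) before they reach the policy network.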
+    env = OneHotShieldingWrapper(env,
+                        config.vector_index if hasattr(config, "vector_index") else 0,
+                        framestack=framestack
+                        )
+
+    return env
+
+
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+
+    ModelCatalog.register_custom_model(
+        "shielding_model", 
+        TorchActionMaskModel
+    )
+
+
+def ppo(args):
+    register_minigrid_shielding_env(args)
+    
+    config = (PPOConfig()
+        .rollouts(num_rollout_workers=args.workers)
+        .resources(num_gpus=0)
+        .environment( env="mini-grid-shielding",
+                      env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
+        .framework("torch")
+        .evaluation(evaluation_interval=1,
+                    evaluation_parallel_to_training=False,
+                    evaluation_config={"env": "mini-grid-shielding",
+                                       "env_config": {"name": args.env, "args": args, "shielding": args.shielding in [ShieldingConfig.Enabled, ShieldingConfig.Evaluation]}})
+        .rl_module(_enable_rl_module_api=False)
+        .debugging(logger_config={
+            "type": TBXLogger, 
+            "logdir": create_log_dir(args)
+        })
+        .training(_enable_learner_api=False, model={
+            "custom_model": "shielding_model"
+        }))
+
+    algo = config.build()
+    
+    iterations = args.iterations
+    
+    for i in range(iterations):
+        algo.train()
+        
+        if i % 5 == 0:
+            algo.save()
+
+    for i in range(iterations):
+        eval_result = algo.evaluate()
+        print(pretty_print(eval_result))
+        
+
+def main():
+    import argparse
+    args = parse_arguments(argparse)
+
+    ppo(args)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/examples/shields/rl/TorchActionMaskModel.py b/examples/shields/rl/TorchActionMaskModel.py
index 42b6805..b478636 100644
--- a/examples/shields/rl/TorchActionMaskModel.py
+++ b/examples/shields/rl/TorchActionMaskModel.py
@@ -38,9 +38,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
             name + "_internal",
         )
         
-        self.no_masking = False
-        if "no_masking" in model_config["custom_model_config"]:
-            self.no_masking = model_config["custom_model_config"]["no_masking"]
 
     def forward(self, input_dict, state, seq_lens):
         # Extract the available actions tensor from the observation.
@@ -48,10 +45,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})
    
         action_mask = input_dict["obs"]["action_mask"]
-      
-        # If action masking is disabled, directly return unmasked logits
-        if self.no_masking:
-            return logits, state
 
         inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
         masked_logits = logits + inf_mask
diff --git a/examples/shields/rl/Wrappers.py b/examples/shields/rl/Wrappers.py
index ef761fa..0e5cc05 100644
--- a/examples/shields/rl/Wrappers.py
+++ b/examples/shields/rl/Wrappers.py
@@ -82,7 +82,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
 
 
 class MiniGridShieldingWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True):
+    def __init__(self, env, shield_creator: ShieldHandler, create_shield_at_reset=True, mask_actions=True):
         super(MiniGridShieldingWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -94,8 +94,13 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
         self.shield_creator = shield_creator
         self.create_shield_at_reset = create_shield_at_reset
         self.shield = shield_creator.create_shield(env=self.env)
+        self.mask_actions = mask_actions
 
     def create_action_mask(self):
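+        # Without masking, report every action as allowed.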
+        if not self.mask_actions:
+            return np.ones(self.max_available_actions, dtype=np.int8)
+        
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir
 
@@ -146,7 +151,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
         
-        if self.create_shield_at_reset:
+        if self.create_shield_at_reset and self.mask_actions:
             self.shield = self.shield_creator.create_shield(env=self.env)
         
         self.keys = extract_keys(self.env)
diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py
index 59af016..5863d99 100644
--- a/examples/shields/rl/helpers.py
+++ b/examples/shields/rl/helpers.py
@@ -2,6 +2,9 @@ import minigrid
 from minigrid.core.actions import Actions
 
 from datetime import datetime
+from enum import Enum
+
+import os
 
 import stormpy
 import stormpy.core
@@ -13,8 +16,16 @@ import stormpy.logic
 import stormpy.examples
 import stormpy.examples.files
 
+class ShieldingConfig(Enum):
+    Training = 'training'
+    Evaluation = 'evaluation'
+    Disabled = 'none'
+    Enabled = 'full'
+    
+    def __str__(self) -> str:
+        return self.value
+
 
-   
 def extract_keys(env):
     keys = []
     #print(env.grid)
@@ -28,7 +39,7 @@ def extract_keys(env):
     return keys
 
 def create_log_dir(args):
-    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}-env:{args.env}"
+    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}"
 
 
 def get_action_index_mapping(actions):
@@ -77,12 +88,12 @@ def parse_arguments(argparse):
     parser.add_argument("--grid_to_prism_binary_path", default="./main")
     parser.add_argument("--grid_path", default="grid")
     parser.add_argument("--prism_path", default="grid")
-    parser.add_argument("--no_masking", default=False)
     parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
     parser.add_argument("--log_dir", default="../log_results/")
     parser.add_argument("--iterations", type=int, default=30 )
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)
 
     
     args = parser.parse_args()