diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index ae86c40..229b49f 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -78,7 +78,7 @@ def ppo(args):
                                   "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
                                   },)
         .framework("torch")
-        .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]))
+        .callbacks(MyCallbacks, ShieldInfoCallback)
         .evaluation(evaluation_config={
                                        "evaluation_interval": 1,
                                         "evaluation_duration": 10,
diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py
index 1be3ecc..221de80 100644
--- a/examples/shields/rl/callbacks.py
+++ b/examples/shields/rl/callbacks.py
@@ -18,10 +18,10 @@ import matplotlib.pyplot as plt
 import tensorflow as tf
 
 class ShieldInfoCallback(DefaultCallbacks):
-    def on_episode_start(self, log_dir, data) -> None:
+    def on_episode_start(self) -> None:
         file_writer = tf.summary.create_file_writer(log_dir)
         with file_writer.as_default():
-            tf.summary.text("first_text", str(data), step=0)
+            tf.summary.text("first_text", "testing", step=0)
 
     def on_episode_step(self) -> None:
         pass
@@ -37,7 +37,7 @@ class MyCallbacks(DefaultCallbacks):
         episode.hist_data["ran_into_lava"] = []
         episode.hist_data["goals_reached"] = []
         episode.hist_data["ran_into_adversary"] = []
-        
+
         # print("On episode start print")
         # print(env.printGrid())
         # print(worker)
@@ -47,8 +47,8 @@ class MyCallbacks(DefaultCallbacks):
         # print(env.observation_space)
         # plt.imshow(img)
         # plt.show()
-    
-     
+
+
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
         episode.user_data["count"] = episode.user_data["count"] + 1
         env = base_env.get_sub_environments()[0]
@@ -59,22 +59,22 @@ class MyCallbacks(DefaultCallbacks):
                 if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
                     print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}")
                     # assert False
-                
-         
-    
+
+
+
     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
         agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
         ran_into_adversary = False
-        
+
         if hasattr(env, "adversaries"):
             adversaries = env.adversaries.values()
-            for adversary in adversaries:              
+            for adversary in adversaries:
                 if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
                     ran_into_adversary = True
                     break
-        
+
         episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
         episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
         episode.user_data["ran_into_adversary"].append(ran_into_adversary)
@@ -86,10 +86,9 @@ class MyCallbacks(DefaultCallbacks):
         episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
         episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
         episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"]
-        
+
     def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
         print("Evaluate Start")
-        
+
     def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
         print("Evaluate End")
-