diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index e66cfc1..7e5660d 100755
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -63,6 +63,7 @@ def register_minigrid_shielding_env(args):
 
 
 def ppo(args):
+    train_batch_size = 4000
     register_minigrid_shielding_env(args)
     
     config = (PPOConfig()
@@ -77,17 +78,17 @@ def ppo(args):
             "logdir": create_log_dir(args)
         })    
         # .exploration(exploration_config={"exploration_fraction": 0.1})
-        .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model"
-        }))
+        .training(_enable_learner_api=False,
+            model={"custom_model": "shielding_model"},
+            train_batch_size=train_batch_size))
     # config.entropy_coeff =  0.05
     algo =(   
         config.build()
     )   
     
     
-
-    for i in range(args.evaluations):
+    iterations = int(args.steps / train_batch_size) + 1
+    for i in range(iterations):
         result = algo.train()
         print(pretty_print(result))
 
@@ -99,6 +100,7 @@ def ppo(args):
             
 
 def dqn(args):
+    train_batch_size = 4000
     register_minigrid_shielding_env(args)
 
     
@@ -113,15 +115,16 @@ def dqn(args):
             "type": TBXLogger, 
             "logdir": create_log_dir(args)
         })
-    config = config.training(hiddens=[], dueling=False, model={    
+    config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={
             "custom_model": "shielding_model"
     })
     
     algo = (
         config.build()
     )
-         
-    for i in range(args.evaluations):
+
+    iterations = int(args.steps / train_batch_size) + 1
+    for i in range(iterations):
         result = algo.train()
         print(pretty_print(result))
 
diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py
index 56fa8ee..dae1c77 100644
--- a/examples/shields/rl/14_train_eval.py
+++ b/examples/shields/rl/14_train_eval.py
@@ -53,7 +53,7 @@ def register_minigrid_shielding_env(args):
 
 def ppo(args):
     register_minigrid_shielding_env(args)
-    
+    train_batch_size = 4000
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
@@ -74,18 +74,17 @@ def ppo(args):
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "shielding_model"      
-        }))
+        }, train_batch_size=train_batch_size))
     
     algo =(
         
         config.build()
     )
     
-    evaluations = args.evaluations
-    
     
+    iterations = int(args.steps / train_batch_size) + 1
     
-    for i in range(evaluations):
+    for i in range(iterations):
         algo.train()
     
         if i % 5 == 0: