From de5f35dd5fa129e0c546e938b27bcc91d6d98aa9 Mon Sep 17 00:00:00 2001
From: sp <stefan.pranger@iaik.tugraz.at>
Date: Sat, 30 Dec 2023 12:04:21 +0100
Subject: [PATCH] moved text writing to standard callback

---
 examples/shields/rl/15_train_eval_tune.py | 4 ++--
 examples/shields/rl/callbacks.py          | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index a71dfcc..9220af5 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -19,7 +19,7 @@ from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
 from shieldhandlers import MiniGridShieldHandler, create_shield_query
 
 from torch.utils.tensorboard import SummaryWriter
-from callbacks import MyCallbacks, ShieldInfoCallback
+from callbacks import MyCallbacks
 
 
 def shielding_env_creater(config):
@@ -79,7 +79,7 @@ def ppo(args):
                                   "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
                                   },)
         .framework("torch")
-        .callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback]))
+        .callbacks(MyCallbacks)
         .evaluation(evaluation_config={
                                        "evaluation_interval": 1,
                                         "evaluation_duration": 10,
diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py
index 221de80..7d15ca5 100644
--- a/examples/shields/rl/callbacks.py
+++ b/examples/shields/rl/callbacks.py
@@ -28,6 +28,10 @@ class ShieldInfoCallback(DefaultCallbacks):
 
 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
+        file_writer = tf.summary.create_file_writer(log_dir)
+        with file_writer.as_default():
+            tf.summary.text("first_text", "testing", step=0)
+
         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
         episode.user_data["count"] = 0