diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index 21a583d..073cebe 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -33,6 +33,9 @@ class MyCallbacks(DefaultCallbacks): tf.summary.text("first_text", "testing", step=0) def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: + file_writer = tf.summary.create_file_writer(algorithm.logdir) + with file_writer.as_default(): + tf.summary.text("first_text", "testing", step=0) # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] episode.user_data["count"] = 0