diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index 814793a..03c724d 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -27,9 +27,18 @@ class ShieldInfoCallback(DefaultCallbacks): pass class MyCallbacks(DefaultCallbacks): + #def on_algorithm_init(self, algorithm: Algorithm, **kwargs): + # file_writer = tf.summary.FileWriter(algorithm.logdir) + # with file_writer.as_default(): + # tf.summary.text("first_text", "testing", step=0) + # file_writer.flush() + def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: - with open(f"{worker.io_context.log_dir}/testing.txt", "a") as file: - file.write("first_text_from_episode_start\n") + file_writer = tf.summary.create_file_writer(f"{worker.io_context.log_dir}/shield_data") + print(file_writer) + with file_writer.as_default(): + tf.summary.text("first_text_from_episode_start", "testing_in_episode", step=0) + file_writer.flush() # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] episode.user_data["count"] = 0