From f35083e669a1e894721a00a8f56755a703c00d8a Mon Sep 17 00:00:00 2001 From: sp Date: Sat, 30 Dec 2023 12:23:54 +0100 Subject: [PATCH] callback on algo init --- examples/shields/rl/callbacks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index 7d15ca5..ac4370d 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -27,11 +27,12 @@ class ShieldInfoCallback(DefaultCallbacks): pass class MyCallbacks(DefaultCallbacks): - def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: - file_writer = tf.summary.create_file_writer(log_dir) + def on_algorithm_init(self, algorithm: Algorithm, **kwargs): + file_writer = tf.summary.create_file_writer(algorithm.log_dir) with file_writer.as_default(): tf.summary.text("first_text", "testing", step=0) + def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] episode.user_data["count"] = 0