Browse Source

another try

refactoring
sp 11 months ago
parent
commit
e8dc44673d
  1. 5
      examples/shields/rl/callbacks.py

5
examples/shields/rl/callbacks.py

@ -28,11 +28,12 @@ class ShieldInfoCallback(DefaultCallbacks):
class MyCallbacks(DefaultCallbacks): class MyCallbacks(DefaultCallbacks):
def on_algorithm_init(self, algorithm: Algorithm, **kwargs): def on_algorithm_init(self, algorithm: Algorithm, **kwargs):
self.file_writer = tf.summary.create_file_writer(algorithm.logdir)
with self.file_writer.as_default():
file_writer = tf.summary.create_file_writer(algorithm.logdir)
with file_writer.as_default():
tf.summary.text("first_text", "testing", step=0) tf.summary.text("first_text", "testing", step=0)
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
file_writer = tf.summary.create_file_writer(worker.io_context.log_dir)
with self.file_writer.as_default(): with self.file_writer.as_default():
tf.summary.text("first_text_from_episode_start", "testing_in_episode", step=0) tf.summary.text("first_text_from_episode_start", "testing_in_episode", step=0)
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")

Loading…
Cancel
Save