|
|
@ -28,14 +28,13 @@ class ShieldInfoCallback(DefaultCallbacks): |
|
|
|
|
|
|
|
class MyCallbacks(DefaultCallbacks): |
|
|
|
def on_algorithm_init(self, algorithm: Algorithm, **kwargs): |
|
|
|
file_writer = tf.summary.create_file_writer(algorithm.logdir) |
|
|
|
with file_writer.as_default(): |
|
|
|
self.file_writer = tf.summary.create_file_writer(algorithm.logdir) |
|
|
|
with self.file_writer.as_default(): |
|
|
|
tf.summary.text("first_text", "testing", step=0) |
|
|
|
|
|
|
|
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: |
|
|
|
file_writer = tf.summary.create_file_writer(algorithm.logdir) |
|
|
|
with file_writer.as_default(): |
|
|
|
tf.summary.text("first_text", "testing", step=0) |
|
|
|
with self.file_writer.as_default(): |
|
|
|
tf.summary.text("first_text_from_episode_start", "testing_in_episode", step=0) |
|
|
|
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") |
|
|
|
env = base_env.get_sub_environments()[0] |
|
|
|
episode.user_data["count"] = 0 |