Browse Source

removed callbacks for shield info

refactoring
Stefan Pranger 11 months ago
parent
commit
f3b12f4caa
  1. 19
      examples/shields/rl/callbacks.py

19
examples/shields/rl/callbacks.py

@ -17,27 +17,8 @@ import matplotlib.pyplot as plt
import tensorflow as tf import tensorflow as tf
class ShieldInfoCallback(DefaultCallbacks):
def on_episode_start(self) -> None:
file_writer = tf.summary.create_file_writer(log_dir)
with file_writer.as_default():
tf.summary.text("first_text", "testing", step=0)
def on_episode_step(self) -> None:
pass
class MyCallbacks(DefaultCallbacks): class MyCallbacks(DefaultCallbacks):
#def on_algorithm_init(self, algorithm: Algorithm, **kwargs):
# file_writer = tf.summary.FileWriter(algorithm.logdir)
# with file_writer.as_default():
# tf.summary.text("first_text", "testing", step=0)
# file_writer.flush()
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
file_writer = tf.summary.create_file_writer(f"{worker.io_context.log_dir}/shield_data")
print(file_writer.logdir)
file_writer.add_text("first_text_from_episode_start", "testing_in_episode", 0)
file_writer.flush()
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0] env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0 episode.user_data["count"] = 0

Loading…
Cancel
Save