diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index a71dfcc..9220af5 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -19,7 +19,7 @@ from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name from shieldhandlers import MiniGridShieldHandler, create_shield_query from torch.utils.tensorboard import SummaryWriter -from callbacks import MyCallbacks, ShieldInfoCallback +from callbacks import MyCallbacks def shielding_env_creater(config): @@ -79,7 +79,7 @@ def ppo(args): "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, },) .framework("torch") - .callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback])) + .callbacks(MyCallbacks) .evaluation(evaluation_config={ "evaluation_interval": 1, "evaluation_duration": 10, diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index 221de80..7d15ca5 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -28,6 +28,10 @@ class ShieldInfoCallback(DefaultCallbacks): class MyCallbacks(DefaultCallbacks): def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: + file_writer = tf.summary.create_file_writer(log_dir) + with file_writer.as_default(): + tf.summary.text("first_text", "testing", step=0) + # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") env = base_env.get_sub_environments()[0] episode.user_data["count"] = 0