From 07841090c62b1bf73ebb5a4a0ea5079a98b08622 Mon Sep 17 00:00:00 2001
From: sp
Date: Sat, 30 Dec 2023 11:46:16 +0100
Subject: [PATCH] shield info callback, removed whitespace

---
 examples/shields/rl/15_train_eval_tune.py |  2 +-
 examples/shields/rl/callbacks.py          | 31 ++++++++++++++++---------------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index ae86c40..229b49f 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -78,7 +78,7 @@ def ppo(args):
             "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
         },)
         .framework("torch")
-        .callbacks(MyCallbacks, ShieldInfoCallback(logdir, [1,12]))
+        .callbacks(MyCallbacks, ShieldInfoCallback)
         .evaluation(evaluation_config={
             "evaluation_interval": 1,
             "evaluation_duration": 10,
diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py
index 1be3ecc..221de80 100644
--- a/examples/shields/rl/callbacks.py
+++ b/examples/shields/rl/callbacks.py
@@ -18,10 +18,12 @@ import matplotlib.pyplot as plt
 import tensorflow as tf
 
 class ShieldInfoCallback(DefaultCallbacks):
-    def on_episode_start(self, log_dir, data) -> None:
-        file_writer = tf.summary.create_file_writer(log_dir)
+    def on_episode_start(self, *, worker, **kwargs) -> None:
+        # RLlib calls this hook with keyword arguments only; take the log dir
+        # from the rollout worker, since the class is now passed uninstantiated.
+        file_writer = tf.summary.create_file_writer(worker.io_context.log_dir)
         with file_writer.as_default():
-            tf.summary.text("first_text", str(data), step=0)
+            tf.summary.text("first_text", "testing", step=0)
 
     def on_episode_step(self) -> None:
         pass
@@ -37,7 +37,7 @@ class MyCallbacks(DefaultCallbacks):
         episode.hist_data["ran_into_lava"] = []
         episode.hist_data["goals_reached"] = []
         episode.hist_data["ran_into_adversary"] = []
-        
+
         # print("On episode start print")
         # print(env.printGrid())
         # print(worker)
@@ -47,8 +47,8 @@
         # print(env.observation_space)
         # plt.imshow(img)
         # plt.show()
-        
-        
+
+
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
        episode.user_data["count"] = episode.user_data["count"] + 1
        env = base_env.get_sub_environments()[0]
@@ -59,22 +59,22 @@
             if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
                 print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}")
                 # assert False
-        
-        
-        
+
+
+
     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
         # print(F"Episode end Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
         agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
         ran_into_adversary = False
-        
+
         if hasattr(env, "adversaries"):
             adversaries = env.adversaries.values()
-            for adversary in adversaries: 
+            for adversary in adversaries:
                 if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
                     ran_into_adversary = True
                     break
-        
+
         episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
         episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
         episode.user_data["ran_into_adversary"].append(ran_into_adversary)
@@ -86,10 +86,9 @@ class MyCallbacks(DefaultCallbacks):
         episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
         episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
         episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"]
-        
+
     def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
         print("Evaluate Start")
-        
+
     def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
         print("Evaluate End")
-
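
Note on the first hunk: AlgorithmConfig.callbacks() in RLlib accepts a single callbacks class, so passing MyCallbacks and ShieldInfoCallback as two positional arguments will fail rather than register both. A minimal sketch of one way to combine them, assuming a Ray 2.x release that provides make_multi_callbacks; the name config here stands in for the PPOConfig chain built in ppo():

    from ray.rllib.algorithms.callbacks import make_multi_callbacks

    # Fold both callback classes into one DefaultCallbacks subclass. RLlib
    # instantiates the resulting class itself with no arguments, which is why
    # ShieldInfoCallback can no longer receive logdir through its constructor
    # and instead reads the log dir off the rollout worker at call time.
    config = config.callbacks(make_multi_callbacks([MyCallbacks, ShieldInfoCallback]))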