|
|
@ -18,10 +18,10 @@ import matplotlib.pyplot as plt |
|
|
|
import tensorflow as tf |
|
|
|
|
|
|
|
class ShieldInfoCallback(DefaultCallbacks): |
|
|
|
def on_episode_start(self, log_dir, data) -> None: |
|
|
|
def on_episode_start(self) -> None: |
|
|
|
file_writer = tf.summary.create_file_writer(log_dir) |
|
|
|
with file_writer.as_default(): |
|
|
|
tf.summary.text("first_text", str(data), step=0) |
|
|
|
tf.summary.text("first_text", "testing", step=0) |
|
|
|
|
|
|
|
def on_episode_step(self) -> None: |
|
|
|
pass |
|
|
@ -37,7 +37,7 @@ class MyCallbacks(DefaultCallbacks): |
|
|
|
episode.hist_data["ran_into_lava"] = [] |
|
|
|
episode.hist_data["goals_reached"] = [] |
|
|
|
episode.hist_data["ran_into_adversary"] = [] |
|
|
|
|
|
|
|
|
|
|
|
# print("On episode start print") |
|
|
|
# print(env.printGrid()) |
|
|
|
# print(worker) |
|
|
@ -47,8 +47,8 @@ class MyCallbacks(DefaultCallbacks): |
|
|
|
# print(env.observation_space) |
|
|
|
# plt.imshow(img) |
|
|
|
# plt.show() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: |
|
|
|
episode.user_data["count"] = episode.user_data["count"] + 1 |
|
|
|
env = base_env.get_sub_environments()[0] |
|
|
@ -59,22 +59,22 @@ class MyCallbacks(DefaultCallbacks): |
|
|
|
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]: |
|
|
|
print(F"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}") |
|
|
|
# assert False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: |
|
|
|
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
|
|
|
env = base_env.get_sub_environments()[0] |
|
|
|
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1]) |
|
|
|
ran_into_adversary = False |
|
|
|
|
|
|
|
|
|
|
|
if hasattr(env, "adversaries"): |
|
|
|
adversaries = env.adversaries.values() |
|
|
|
for adversary in adversaries: |
|
|
|
for adversary in adversaries: |
|
|
|
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]: |
|
|
|
ran_into_adversary = True |
|
|
|
break |
|
|
|
|
|
|
|
|
|
|
|
episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal") |
|
|
|
episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava") |
|
|
|
episode.user_data["ran_into_adversary"].append(ran_into_adversary) |
|
|
@ -86,10 +86,9 @@ class MyCallbacks(DefaultCallbacks): |
|
|
|
episode.hist_data["goals_reached"] = episode.user_data["goals_reached"] |
|
|
|
episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"] |
|
|
|
episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"] |
|
|
|
|
|
|
|
|
|
|
|
def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None: |
|
|
|
print("Evaluate Start") |
|
|
|
|
|
|
|
|
|
|
|
def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None: |
|
|
|
print("Evaluate End") |
|
|
|
|
xxxxxxxxxx