from typing import Dict, Union

import matplotlib.pyplot as plt
import tensorflow as tf

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID

class ShieldInfoCallback(DefaultCallbacks):
    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2], env_index: int, **kwargs) -> None:
        # Log a test text summary into this worker's log directory so it
        # shows up in TensorBoard.
        file_writer = tf.summary.create_file_writer(worker.io_context.log_dir)
        with file_writer.as_default():
            tf.summary.text("first_text", "testing", step=0)

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2], env_index: int, **kwargs) -> None:
        pass

class MyCallbacks(DefaultCallbacks):
    # def on_algorithm_init(self, algorithm: Algorithm, **kwargs):
    #     file_writer = tf.summary.create_file_writer(algorithm.logdir)
    #     with file_writer.as_default():
    #         tf.summary.text("first_text", "testing", step=0)
    #     file_writer.flush()

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2], env_index: int, **kwargs) -> None:
        # Log a test text summary into a "shield_data" subdirectory of this
        # worker's log directory.
        log_dir = f"{worker.io_context.log_dir}/shield_data"
        print(log_dir)
        file_writer = tf.summary.create_file_writer(log_dir)
        with file_writer.as_default():
            tf.summary.text("first_text_from_episode_start", "testing_in_episode", step=0)
        file_writer.flush()
        # print(f"Episode started. Environment: {base_env.get_sub_environments()}")
        env = base_env.get_sub_environments()[0]
        # Per-episode bookkeeping: user_data is free-form storage for the
        # running episode, hist_data ends up in the result histograms.
        episode.user_data["count"] = 0
        episode.user_data["ran_into_lava"] = []
        episode.user_data["goals_reached"] = []
        episode.user_data["ran_into_adversary"] = []
        episode.hist_data["ran_into_lava"] = []
        episode.hist_data["goals_reached"] = []
        episode.hist_data["ran_into_adversary"] = []
        # Debugging helpers:
        # print(env.printGrid())
        # print(worker)
        # print(env.action_space.n)
        # print(env.actions)
        # print(env.mission)
        # print(env.observation_space)
        # plt.imshow(img)
        # plt.show()

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2], env_index: int, **kwargs) -> None:
        episode.user_data["count"] += 1
        env = base_env.get_sub_environments()[0]
        # print(env.printGrid())
        # Report collisions between the agent and any adversary.
        if hasattr(env, "adversaries"):
            for adversary in env.adversaries.values():
                if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
                    print(f"Adversary ran into agent. Adversary {adversary.cur_pos}, Agent {env.agent_pos}")
                    # assert False

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Union[Episode, EpisodeV2], env_index: int, **kwargs) -> None:
        # print(f"Episode end. Environment: {base_env.get_sub_environments()}")
        env = base_env.get_sub_environments()[0]
        # The tile the agent stopped on tells us how the episode ended.
        agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
        ran_into_adversary = False
        if hasattr(env, "adversaries"):
            for adversary in env.adversaries.values():
                if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
                    ran_into_adversary = True
                    break
        episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
        episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
        episode.user_data["ran_into_adversary"].append(ran_into_adversary)
        episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal"
        episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava"
        episode.custom_metrics["ran_into_adversary"] = ran_into_adversary
        # print("On episode end print")
        # print(env.printGrid())
        episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
        episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
        episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"]
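        # Note: RLlib aggregates episode.custom_metrics into mean/min/max
        # entries under "custom_metrics" in the training results, and
        # episode.hist_data is reported under "hist_stats".
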
    def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
        print("Evaluate Start")

    def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
        print("Evaluate End")
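
# Usage sketch (an assumption, not from this file; assumes Ray 2.x's
# AlgorithmConfig API, with PPO and CartPole-v1 as placeholders): callbacks
# are attached via AlgorithmConfig.callbacks(), and make_multi_callbacks
# (imported above) combines several callback classes:
#
#   from ray.rllib.algorithms.ppo import PPOConfig
#
#   config = (
#       PPOConfig()
#       .environment("CartPole-v1")
#       .callbacks(make_multi_callbacks([ShieldInfoCallback, MyCallbacks]))
#   )
#   algo = config.build()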