You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
3.2 KiB

11 months ago
  1. from typing import Dict, Optional
  2. from ray.rllib.env.env_context import EnvContext
  3. from ray.rllib.policy import Policy
  4. from ray.rllib.utils.typing import EnvType, PolicyID
  5. from ray.rllib.algorithms.algorithm import Algorithm
  6. from ray.rllib.env.base_env import BaseEnv
  7. from ray.rllib.evaluation import RolloutWorker
  8. from ray.rllib.evaluation.episode import Episode
  9. from ray.rllib.evaluation.episode_v2 import EpisodeV2
  10. from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
  11. from ray.tune import Callback
  12. import matplotlib.pyplot as plt
  13. class CustomCallback(DefaultCallbacks):
  14. def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
  15. env = base_env.get_sub_environments()[0]
  16. episode.user_data["count"] = 0
  17. episode.user_data["ran_into_lava"] = []
  18. episode.user_data["goals_reached"] = []
  19. episode.user_data["ran_into_adversary"] = []
  20. episode.hist_data["ran_into_lava"] = []
  21. episode.hist_data["goals_reached"] = []
  22. episode.hist_data["ran_into_adversary"] = []
  23. def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
  24. episode.user_data["count"] = episode.user_data["count"] + 1
  25. env = base_env.get_sub_environments()[0]
  26. # print(env.printGrid())
  27. def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
  28. # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
  29. env = base_env.get_sub_environments()[0]
  30. agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
  31. ran_into_adversary = False
  32. if hasattr(env, "adversaries") and env.adversaries:
  33. adversaries = env.adversaries.values()
  34. for adversary in adversaries:
  35. if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
  36. ran_into_adversary = True
  37. break
  38. episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
  39. episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
  40. episode.user_data["ran_into_adversary"].append(ran_into_adversary)
  41. episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal"
  42. episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava"
  43. episode.custom_metrics["ran_into_adversary"] = ran_into_adversary
  44. #print("On episode end print")
  45. # print(env.printGrid())
  46. episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
  47. episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
  48. episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"]
  49. def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
  50. print("Evaluate Start")
  51. def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
  52. print("Evaluate End")