|
@ -13,7 +13,7 @@ from ray.rllib.evaluation.episode_v2 import EpisodeV2 |
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks |
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks |
|
|
|
|
|
|
|
|
class MyCallbacks(DefaultCallbacks): |
|
|
class MyCallbacks(DefaultCallbacks): |
|
|
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
|
|
|
|
|
|
|
|
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: |
|
|
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") |
|
|
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}") |
|
|
env = base_env.get_sub_environments()[0] |
|
|
env = base_env.get_sub_environments()[0] |
|
|
episode.user_data["count"] = 0 |
|
|
episode.user_data["count"] = 0 |
|
@ -33,12 +33,12 @@ class MyCallbacks(DefaultCallbacks): |
|
|
# plt.show() |
|
|
# plt.show() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: |
|
|
|
|
|
|
|
|
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: |
|
|
episode.user_data["count"] = episode.user_data["count"] + 1 |
|
|
episode.user_data["count"] = episode.user_data["count"] + 1 |
|
|
env = base_env.get_sub_environments()[0] |
|
|
env = base_env.get_sub_environments()[0] |
|
|
# print(env.printGrid()) |
|
|
# print(env.printGrid()) |
|
|
|
|
|
|
|
|
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: |
|
|
|
|
|
|
|
|
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: |
|
|
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
|
|
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}") |
|
|
env = base_env.get_sub_environments()[0] |
|
|
env = base_env.get_sub_environments()[0] |
|
|
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1]) |
|
|
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1]) |
|
|