@@ -1,8 +1,9 @@
-from typing import Dict
+from typing import Dict, Optional
+from ray.rllib.env.env_context import EnvContext
 from ray.rllib.policy import Policy
-from ray.rllib.utils.typing import PolicyID
+from ray.rllib.utils.typing import EnvType, PolicyID
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.env.base_env import BaseEnv
@@ -12,6 +13,10 @@ from ray.rllib.evaluation.episode_v2 import EpisodeV2
 from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
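+# matplotlib is only used by the commented-out frame-rendering snippets below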
+import matplotlib.pyplot as plt


 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
         # print(f"Episode started. Environment: {base_env.get_sub_environments()}")
@@ -19,8 +24,11 @@ class MyCallbacks(DefaultCallbacks):
         episode.user_data["count"] = 0
         episode.user_data["ran_into_lava"] = []
         episode.user_data["goals_reached"] = []
+        episode.user_data["ran_into_adversary"] = []
         episode.hist_data["ran_into_lava"] = []
         episode.hist_data["goals_reached"] = []
+        episode.hist_data["ran_into_adversary"] = []
         # print("On episode start print")
         # print(env.printGrid())
         # print(worker)
@@ -28,30 +36,39 @@
         # print(env.actions)
         # print(env.mission)
         # print(env.observation_space)
-        # img = env.get_frame()
         # plt.imshow(img)
         # plt.show()

     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
         episode.user_data["count"] = episode.user_data["count"] + 1
         env = base_env.get_sub_environments()[0]
         # print(env.printGrid())

     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
         # print(f"Episode end. Environment: {base_env.get_sub_environments()}")
         env = base_env.get_sub_environments()[0]
         agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
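+        # Flag whether the agent's final position coincides with any adversary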
+        ran_into_adversary = False
+        if hasattr(env, "adversaries"):
+            adversaries = env.adversaries.values()
+            for adversary in adversaries:
+                if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:
+                    ran_into_adversary = True
+                    break
episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal") |
|
|
episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal") |
|
|
episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava") |
|
|
episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava") |
|
|
|
|
|
episode.user_data["ran_into_adversary"].append(ran_into_adversary) |
|
|
episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal" |
|
|
episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal" |
|
|
episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava" |
|
|
episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava" |
|
|
|
|
|
episode.custom_metrics["ran_into_adversary"] = ran_into_adversary |
|
|
#print("On episode end print") |
|
|
#print("On episode end print") |
|
|
#print(env.printGrid()) |
|
|
|
|
|
|
|
|
# print(env.printGrid()) |
|
|
episode.hist_data["goals_reached"] = episode.user_data["goals_reached"] |
|
|
episode.hist_data["goals_reached"] = episode.user_data["goals_reached"] |
|
|
episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"] |
|
|
episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"] |
|
|
|
|
|
|
|
|
|
|
|
episode.hist_data["ran_into_adversary"] = episode.user_data["ran_into_adversary"] |
|
|
|
|
|
|
|
|
def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None: |
|
|
def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None: |
|
|
print("Evaluate Start") |
|
|
print("Evaluate Start") |
|
|
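For reference, a minimal sketch of how callbacks like these are hooked into a training run. This is not part of the change above: it assumes the Ray 2.x AlgorithmConfig API, and the PPO algorithm and environment id are placeholders (the callbacks expect a MiniGrid-style env exposing grid, agent_pos, and optionally adversaries):

    from ray.rllib.algorithms.ppo import PPOConfig

    config = (
        PPOConfig()
        # Placeholder id: any registered MiniGrid-style env works here.
        .environment("MiniGrid-Empty-5x5-v0")
        # Register the callbacks class from the diff above.
        .callbacks(MyCallbacks)
    )
    algo = config.build()
    result = algo.train()
    # Episode custom_metrics are aggregated into mean/min/max entries.
    print(result["custom_metrics"]["ran_into_lava_mean"])

To run several callback classes side by side, pass make_multi_callbacks([MyCallbacks, ...]) (already imported above) to .callbacks() instead of a single class.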