from typing import Dict

from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from ray.rllib.algorithms.algorithm import Algorithm

import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor

from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir

import matplotlib.pyplot as plt


class MyCallbacks(DefaultCallbacks):
    """Tracks the number of steps per episode via episode.user_data["count"]."""

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[PolicyID, Policy],
                         episode: Episode | EpisodeV2,
                         env_index: int | None = None, **kwargs) -> None:
        # print(F"Episode started. Environment: {base_env.get_sub_environments()}")
        env = base_env.get_sub_environments()[0]
        episode.user_data["count"] = 0
        # Debugging helpers for inspecting the environment:
        # print(env.printGrid())
        # print(env.action_space.n)
        # print(env.actions)
        # print(env.mission)
        # print(env.observation_space)
        # img = env.get_frame()
        # plt.imshow(img)
        # plt.show()

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        policies: Dict[PolicyID, Policy] | None = None,
                        episode: Episode | EpisodeV2,
                        env_index: int | None = None, **kwargs) -> None:
        episode.user_data["count"] += 1
        env = base_env.get_sub_environments()[0]
        # print(env.printGrid())

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[PolicyID, Policy],
                       episode: Episode | EpisodeV2 | Exception,
                       env_index: int | None = None, **kwargs) -> None:
        # print(F"Episode ended. Environment: {base_env.get_sub_environments()}")
        env = base_env.get_sub_environments()[0]
        # print(env.printGrid())
        # print(episode.user_data["count"])


def env_creater_custom(config):
    """Builds the wrapped MiniGrid environment used by the RLlib workers."""
    framestack = config.get("framestack", 4)
    shield = config.get("shield", {})
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    args = config.get("args", None)

    env = gym.make(name)
    env = MiniGridEnvWrapper(env, args=args)
    # env = minigrid.wrappers.ImgObsWrapper(env)
    # env = ImgObsWrapper(env)
    env = OneHotWrapper(env,
                        config.vector_index if hasattr(config, "vector_index") else 0,
                        framestack=framestack)

    return env


def register_custom_minigrid_env(args):
    """Registers the custom environment and the action-masking model with RLlib."""
    env_name = "mini-grid"
    register_env(env_name, env_creater_custom)

    ModelCatalog.register_custom_model(
        "pa_model",
        TorchActionMaskModel
    )


def ppo(args):
    ray.init(num_cpus=1)

    register_custom_minigrid_env(args)

    config = (PPOConfig()
              .rollouts(num_rollout_workers=1)
              .resources(num_gpus=0)
              .environment(env="mini-grid", env_config={"name": args.env, "args": args})
              .framework("torch")
              .callbacks(MyCallbacks)
              .rl_module(_enable_rl_module_api=False)
              .debugging(logger_config={
                  "type": "ray.tune.logger.TBXLogger",
                  "logdir": create_log_dir(args)
              })
              .training(_enable_learner_api=False, model={
                  "custom_model": "pa_model",
                  "custom_model_config": {"no_masking": args.no_masking}
              }))

    algo = config.build()

    for i in range(args.iterations):
        result = algo.train()
        print(pretty_print(result))

        if i % 5 == 0:
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")

    # Manual rollout with the trained policy (note that gymnasium's step()
    # returns a 5-tuple, including an info dict):
    # obs, info = env.reset()
    # terminated = truncated = False
    # while not terminated and not truncated:
    #     action = algo.compute_single_action(obs)
    #     obs, reward, terminated, truncated, info = env.step(action)

    ray.shutdown()
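

# The commented-out rollout in ppo() sketches manual policy evaluation. Below is a
# minimal, hedged version of that idea: it builds a standalone environment with the
# same env_config used for training and steps it with the trained algorithm.
# Assumptions: `algo.compute_single_action` (standard RLlib API) works with the
# wrapped observation space, and the helper name `evaluate_policy` plus the
# `episodes` parameter are illustrative additions, not part of the original script.
def evaluate_policy(algo, env_config, episodes=1):
    env = env_creater_custom(env_config)
    for _ in range(episodes):
        obs, info = env.reset()
        terminated = truncated = False
        total_reward = 0.0
        while not terminated and not truncated:
            action = algo.compute_single_action(obs)
            # Gymnasium's step() returns (obs, reward, terminated, truncated, info).
            obs, reward, terminated, truncated, info = env.step(action)
            total_reward += reward
        print(f"Episode return: {total_reward}")
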

def dqn(args):
    register_custom_minigrid_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=1)
    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": "ray.tune.logger.TBXLogger",
        "logdir": create_log_dir(args)
    })
    # Plain Q-head (no dueling, no extra hidden layers) so the custom
    # action-masking model's outputs are used directly.
    config = config.training(hiddens=[], dueling=False, model={
        "custom_model": "pa_model",
        "custom_model_config": {"no_masking": args.no_masking}
    })

    algo = config.build()

    for i in range(args.iterations):
        result = algo.train()
        print(pretty_print(result))

        if i % 5 == 0:
            print("Saving checkpoint")
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")

    ray.shutdown()


def main():
    import argparse
    args = parse_arguments(argparse)

    if args.algorithm == "ppo":
        ppo(args)
    elif args.algorithm == "dqn":
        dqn(args)


if __name__ == '__main__':
    main()
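

# A minimal, hedged sketch (kept commented out) of restoring a saved checkpoint for
# later use. It assumes a checkpoint directory produced by algo.save() above (the
# path is a placeholder) and Ray 2.x's Algorithm.from_checkpoint(). The custom env
# and model must be registered first, because the stored config references
# "mini-grid" and "pa_model".
#
# import argparse
# args = parse_arguments(argparse)
# register_custom_minigrid_env(args)
# restored_algo = Algorithm.from_checkpoint("/path/to/checkpoint_dir")
# print(pretty_print(restored_algo.train()))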