You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

188 lines
5.8 KiB

from typing import Dict
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.models.preprocessors import get_preprocessor
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from helpers import extract_keys, parse_arguments, create_shield_dict, create_log_dir
import matplotlib.pyplot as plt
class MyCallbacks(DefaultCallbacks):
    """RLlib callbacks that keep a per-episode step counter.

    The counter lives in ``episode.user_data["count"]``: it is reset when an
    episode starts and incremented on every ``on_episode_step`` invocation.

    For ad-hoc debugging, the first sub-environment can be obtained inside
    any hook with ``base_env.get_sub_environments()[0]``.
    """

    def on_episode_start(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Reset the step counter stored on the episode."""
        episode.user_data["count"] = 0

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy] | None = None,
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Increment the step counter stored on the episode."""
        # No environment lookup here: this hook runs on every step and the
        # sub-environment was never used.
        episode.user_data["count"] += 1

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2 | Exception,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Do nothing; kept as an explicit extension point for debugging.

        Note that ``episode`` may be an ``Exception`` here (see the type
        annotation), so ``episode.user_data`` must not be read blindly.
        """
def env_creater_custom(config):
    """Create the wrapped MiniGrid environment described by ``config``.

    Keys read from ``config`` (all optional):
        "name": gymnasium environment id passed to ``gym.make``
            (default ``"MiniGrid-LavaCrossingS9N1-v0"``).
        "shield": forwarded to ``MiniGridEnvWrapper`` (default ``{}``).
        "framestack": forwarded to ``OneHotWrapper`` (default ``4``).

    If ``config`` has a ``vector_index`` attribute it is forwarded to
    ``OneHotWrapper``; otherwise ``0`` is used.

    Returns:
        The environment wrapped first in ``MiniGridEnvWrapper`` and then in
        ``OneHotWrapper``.
    """
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    shield = config.get("shield", {})
    framestack = config.get("framestack", 4)
    # Plain dict configs carry no vector_index attribute; fall back to 0.
    vector_index = getattr(config, "vector_index", 0)

    env = gym.make(name)
    # Keys are extracted from the unwrapped env before any wrapper is applied.
    keys = extract_keys(env)
    env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
    return OneHotWrapper(env, vector_index, framestack=framestack)
def register_custom_minigrid_env(args):
    """Register the custom environment and the custom model by name.

    After this call the environment creator is available under the name
    ``"mini-grid"`` and ``TorchActionMaskModel`` under ``"pa_model"``.

    Args:
        args: Accepted for call-site symmetry; not used by this function.
    """
    register_env("mini-grid", env_creater_custom)
    ModelCatalog.register_custom_model("pa_model", TorchActionMaskModel)
def ppo(args):
    """Train PPO on the registered "mini-grid" environment.

    Starts Ray with a single CPU, builds a PPO algorithm that uses the
    "pa_model" custom model, trains it, prints every iteration's result and
    saves a checkpoint on every fifth iteration (starting with the first).
    Ray is shut down when the function exits, including when it exits with
    an exception.

    Args:
        args: Parsed command-line arguments. ``args.env`` and
            ``args.no_masking`` are read here; ``args`` is also passed to
            ``create_shield_dict`` and ``create_log_dir``. The number of
            training iterations is taken from ``args.iterations`` and falls
            back to 30 when that attribute is missing or ``None``.
    """
    ray.init(num_cpus=1)
    try:
        register_custom_minigrid_env(args)
        shield_dict = create_shield_dict(args)

        config = (
            PPOConfig()
            .rollouts(num_rollout_workers=1)
            .resources(num_gpus=0)
            .environment(
                env="mini-grid",
                env_config={"shield": shield_dict, "name": args.env},
            )
            .framework("torch")
            .callbacks(MyCallbacks)
            .rl_module(_enable_rl_module_api=False)
            .debugging(
                logger_config={
                    "type": "ray.tune.logger.TBXLogger",
                    "logdir": create_log_dir(args),
                }
            )
            .training(
                _enable_learner_api=False,
                model={
                    "custom_model": "pa_model",
                    "custom_model_config": {
                        "shield": shield_dict,
                        "no_masking": args.no_masking,
                    },
                },
            )
        )
        algo = config.build()

        # Honor the iteration count from the command line like dqn() does;
        # 30 was the previously hard-coded value.
        iterations = getattr(args, "iterations", None)
        if iterations is None:
            iterations = 30

        for i in range(iterations):
            result = algo.train()
            print(pretty_print(result))

            if i % 5 == 0:
                checkpoint_dir = algo.save()
                print(f"Checkpoint saved in directory {checkpoint_dir}")
    finally:
        # Always release the Ray runtime, even if build/train/save raised.
        ray.shutdown()
def dqn(args):
    """Train DQN on the registered "mini-grid" environment.

    Builds a DQN algorithm that uses the "pa_model" custom model, trains it
    for ``args.iterations`` iterations, prints every iteration's result and
    saves a checkpoint on every fifth iteration (starting with the first).
    Ray is shut down when the function exits, including when it exits with
    an exception.

    Args:
        args: Parsed command-line arguments. ``args.env``,
            ``args.no_masking`` and ``args.iterations`` are read here;
            ``args`` is also passed to ``create_shield_dict`` and
            ``create_log_dir``.
    """
    register_custom_minigrid_env(args)
    shield_dict = create_shield_dict(args)

    config = (
        DQNConfig()
        .resources(num_gpus=0)
        .rollouts(num_rollout_workers=1)
        .environment(
            env="mini-grid",
            env_config={"shield": shield_dict, "name": args.env},
        )
        .framework("torch")
        .callbacks(MyCallbacks)
        .rl_module(_enable_rl_module_api=False)
        .debugging(
            logger_config={
                "type": "ray.tune.logger.TBXLogger",
                "logdir": create_log_dir(args),
            }
        )
        .training(
            hiddens=[],
            dueling=False,
            model={
                "custom_model": "pa_model",
                "custom_model_config": {
                    "shield": shield_dict,
                    "no_masking": args.no_masking,
                },
            },
        )
    )

    try:
        algo = config.build()

        for i in range(args.iterations):
            result = algo.train()
            print(pretty_print(result))

            if i % 5 == 0:
                print("Saving checkpoint")
                checkpoint_dir = algo.save()
                print(f"Checkpoint saved in directory {checkpoint_dir}")
    finally:
        # Always release the Ray runtime, even if build/train/save raised.
        ray.shutdown()
def main():
    """Parse the command line and run the selected training routine.

    Raises:
        ValueError: If ``args.algorithm`` is neither "ppo" nor "dqn".
            Previously such a value was silently ignored and nothing ran.
    """
    # argparse is imported locally because the module itself is handed to
    # parse_arguments.
    import argparse

    args = parse_arguments(argparse)

    trainers = {"ppo": ppo, "dqn": dqn}
    train = trainers.get(args.algorithm)
    if train is None:
        supported = ", ".join(sorted(trainers))
        raise ValueError(
            f"Unknown algorithm {args.algorithm!r}; expected one of: {supported}"
        )
    train(args)
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()