import gymnasium as gym
import minigrid

from ray import tune, air
from ray.tune import register_env
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from callbacks import CustomCallback

from torch.utils.tensorboard import SummaryWriter
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger
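
# NOTE: register_minigrid_shielding_env below refers to `shielding_env_creater`,
# which is not defined or imported in this snippet; it is expected to come from
# the surrounding project (e.g. rllibutils). Purely as an illustration of what
# such a creator typically does -- the real wrapper signatures may differ -- it
# would build the wrapped environment roughly like this:
#
#   def shielding_env_creater(env_config):
#       env = gym.make(env_config["name"])
#       shield_handler = MiniGridShieldHandler(...)   # computes the shield
#       env = MiniGridShieldingWrapper(env, ...)      # masks unsafe actions
#       env = OneHotShieldingWrapper(env, ...)        # one-hot observations
#       return env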


def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )


def ppo(args):
    register_minigrid_shielding_env(args)

    config = (PPOConfig()
        .rollouts(num_rollout_workers=args.workers)
        .resources(num_gpus=0)
        .environment(env="mini-grid-shielding",
                     env_config={"name": args.env,
                                 "args": args,
                                 # Shield during training if shielding is enabled for
                                 # training only or for both training and evaluation.
                                 "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(CustomCallback)
        # Stay on the legacy ModelV2 API so the custom action-mask model is used.
        .rl_module(_enable_rl_module_api=False)
        .debugging(logger_config={
            "type": TBXLogger,
            "logdir": create_log_dir(args)
        })
        .training(_enable_learner_api=False, model={
            "custom_model": "shielding_model"
        }))

    return config


def dqn(args):
    register_minigrid_shielding_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(CustomCallback)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": TBXLogger,
        "logdir": create_log_dir(args)
    })
    # Disable dueling and the post-processing hiddens so the custom action-mask
    # model's outputs are used directly as Q-values.
    config = config.training(hiddens=[], dueling=False, model={
        "custom_model": "shielding_model"
    })

    return config


def main():
    import argparse
    args = parse_arguments(argparse)

    if args.algorithm == "PPO":
        config = ppo(args)
    elif args.algorithm == "DQN":
        config = dqn(args)
    else:
        raise ValueError(f"Unsupported algorithm: {args.algorithm}")

    logdir = create_log_dir(args)
    tuner = tune.Tuner(
        args.algorithm,
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
            num_samples=1,
        ),
        run_config=air.RunConfig(
            # Stop once the mean episode reward or the total timestep budget is reached.
            stop={"episode_reward_mean": 94,
                  "timesteps_total": 12000},
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
            storage_path=f"{logdir}"
        ),
        param_space=config,
    )

    tuner.fit()


if __name__ == '__main__':
    main()
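
# Example invocation (the flag names below are hypothetical; the actual CLI options
# are defined by parse_arguments in utils.py):
#
#   python <this_script>.py --algorithm PPO --env MiniGrid-LavaCrossingS9N1-v0 --workers 4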