import gymnasium as gym
import minigrid

from ray import tune, air
from ray.tune import register_env
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
# shielding_env_creater is assumed to be provided by rllibutils alongside the wrappers.
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from callbacks import CustomCallback

from torch.utils.tensorboard import SummaryWriter
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger

def register_minigrid_shielding_env(args):
    # Register the shielded MiniGrid environment and the custom action-masking model with RLlib.
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )

def ppo(args):
    # Build a PPO config that trains on the shielded environment with the action-masking model.
    register_minigrid_shielding_env(args)

    config = (PPOConfig()
        .rollouts(num_rollout_workers=args.workers)
        .resources(num_gpus=0)
        .environment(env="mini-grid-shielding",
                     env_config={"name": args.env,
                                 "args": args,
                                 "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(CustomCallback)
        .rl_module(_enable_rl_module_api=False)
        .debugging(logger_config={
            "type": TBXLogger,
            "logdir": create_log_dir(args)
        })
        .training(_enable_learner_api=False, model={
            "custom_model": "shielding_model"
        }))

    return config

def dqn(args):
    # Build a DQN config; dueling and the post-processing hiddens are disabled so the
    # Q-values come directly from the custom action-masking model.
    register_minigrid_shielding_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(CustomCallback)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": TBXLogger,
        "logdir": create_log_dir(args)
    })
    config = config.training(hiddens=[], dueling=False, model={
        "custom_model": "shielding_model"
    })

    return config

def main():
    import argparse
    args = parse_arguments(argparse)

    if args.algorithm == "PPO":
        config = ppo(args)
    elif args.algorithm == "DQN":
        config = dqn(args)
    else:
        raise ValueError(f"Unsupported algorithm: {args.algorithm}")

    logdir = create_log_dir(args)

    tuner = tune.Tuner(
        args.algorithm,
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
            num_samples=1,
        ),
        run_config=air.RunConfig(
            stop={"episode_reward_mean": 94,
                  "timesteps_total": 12000},
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
            storage_path=f"{logdir}",
        ),
        param_space=config,
    )

    tuner.fit()


if __name__ == '__main__':
    main()
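
# Example invocation (a sketch only: the actual flag names and defaults are defined in
# utils.parse_arguments and may differ; --algorithm, --env, --workers, and --shielding
# are assumed here, and the environment id is illustrative):
#
#   python <this_script>.py --algorithm PPO --env MiniGrid-LavaCrossingS9N1-v0 --workers 2 --shielding Full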