import gymnasium as gym
import minigrid

from ray.tune import register_env
from ray.tune.logger import pretty_print, TBXLogger
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
from callbacks import CustomCallback

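# Register the custom shielding environment and the action-masking model with
# RLlib, so both can be referenced by name in the algorithm configs below.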
def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )

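# Train a PPO agent on the shielded MiniGrid environment. The "shielding" flag
# in env_config is enabled when shielding is configured for full use or for
# training only.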
def ppo(args):
    train_batch_size = 4000
    register_minigrid_shielding_env(args)

    config = (
        PPOConfig()
        .rollouts(num_rollout_workers=args.workers)
        .resources(num_gpus=0)
        .environment(
            env="mini-grid-shielding",
            env_config={
                "name": args.env,
                "args": args,
                "shielding": args.shielding is ShieldingConfig.Full
                or args.shielding is ShieldingConfig.Training,
            },
        )
        .framework("torch")
        .callbacks(CustomCallback)
        .rl_module(_enable_rl_module_api=False)
        .debugging(logger_config={
            "type": TBXLogger,
            "logdir": create_log_dir(args),
        })
        # .exploration(exploration_config={"exploration_fraction": 0.1})
        .training(
            _enable_learner_api=False,
            model={"custom_model": "shielding_model"},
            train_batch_size=train_batch_size,
        )
    )
    # config.entropy_coeff = 0.05

    algo = config.build()

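    # Train for roughly args.steps environment steps (one iteration consumes
    # train_batch_size steps), saving a checkpoint every fifth iteration and
    # once more after the final one.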
    iterations = int(args.steps / train_batch_size) + 1
    for i in range(iterations):
        result = algo.train()
        print(pretty_print(result))
        if i % 5 == 0:
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")
    algo.save()

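# Train a DQN agent on the same environment. Unlike ppo(), no explicit
# "shielding" entry is passed via env_config here, so shielding_env_creater is
# presumably left to fall back on its default behaviour.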
def dqn(args):
    train_batch_size = 4000
    register_minigrid_shielding_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(CustomCallback)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": TBXLogger,
        "logdir": create_log_dir(args),
    })
    config = config.training(
        hiddens=[],
        dueling=False,
        train_batch_size=train_batch_size,
        model={"custom_model": "shielding_model"},
    )

    algo = config.build()

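    # Same training loop as in ppo(): roughly args.steps environment steps,
    # with a checkpoint every fifth iteration.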
    iterations = int(args.steps / train_batch_size) + 1
    for i in range(iterations):
        result = algo.train()
        print(pretty_print(result))
        if i % 5 == 0:
            print("Saving checkpoint")
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")

def main():
    import argparse

    args = parse_arguments(argparse)

    if args.algorithm == "PPO":
        ppo(args)
    elif args.algorithm == "DQN":
        dqn(args)


if __name__ == '__main__':
    main()