You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 lines
3.4 KiB

11 months ago
11 months ago
11 months ago
11 months ago
  1. import gymnasium as gym
  2. import minigrid
  3. from ray.tune import register_env
  4. from ray.rllib.algorithms.ppo import PPOConfig
  5. from ray.rllib.algorithms.dqn.dqn import DQNConfig
  6. from ray.tune.logger import pretty_print
  7. from ray.rllib.models import ModelCatalog
  8. from torch_action_mask_model import TorchActionMaskModel
  9. from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
  10. from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
  11. from callbacks import CustomCallback
  12. from ray.tune.logger import TBXLogger
  13. def register_minigrid_shielding_env(args):
  14. env_name = "mini-grid-shielding"
  15. register_env(env_name, shielding_env_creater)
  16. ModelCatalog.register_custom_model(
  17. "shielding_model",
  18. TorchActionMaskModel
  19. )
  20. def ppo(args):
  21. train_batch_size = 4000
  22. register_minigrid_shielding_env(args)
  23. config = (PPOConfig()
  24. .rollouts(num_rollout_workers=args.workers)
  25. .resources(num_gpus=0)
  26. .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
  27. .framework("torch")
  28. .callbacks(CustomCallback)
  29. .rl_module(_enable_rl_module_api = False)
  30. .debugging(logger_config={
  31. "type": TBXLogger,
  32. "logdir": create_log_dir(args)
  33. })
  34. # .exploration(exploration_config={"exploration_fraction": 0.1})
  35. .training(_enable_learner_api=False ,
  36. model={"custom_model": "shielding_model"},
  37. train_batch_size=train_batch_size))
  38. # config.entropy_coeff = 0.05
  39. algo =(
  40. config.build()
  41. )
  42. iterations = int((args.steps / train_batch_size)) + 1
  43. for i in range(iterations):
  44. result = algo.train()
  45. print(pretty_print(result))
  46. if i % 5 == 0:
  47. checkpoint_dir = algo.save()
  48. print(f"Checkpoint saved in directory {checkpoint_dir}")
  49. algo.save()
  50. def dqn(args):
  51. train_batch_size = 4000
  52. register_minigrid_shielding_env(args)
  53. config = DQNConfig()
  54. config = config.resources(num_gpus=0)
  55. config = config.rollouts(num_rollout_workers=args.workers)
  56. config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
  57. config = config.framework("torch")
  58. config = config.callbacks(CustomCallback)
  59. config = config.rl_module(_enable_rl_module_api = False)
  60. config = config.debugging(logger_config={
  61. "type": TBXLogger,
  62. "logdir": create_log_dir(args)
  63. })
  64. config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={
  65. "custom_model": "shielding_model"
  66. })
  67. algo = (
  68. config.build()
  69. )
  70. iterations = int((args.steps / train_batch_size)) + 1
  71. for i in range(iterations):
  72. result = algo.train()
  73. print(pretty_print(result))
  74. if i % 5 == 0:
  75. print("Saving checkpoint")
  76. checkpoint_dir = algo.save()
  77. print(f"Checkpoint saved in directory {checkpoint_dir}")
  78. def main():
  79. import argparse
  80. args = parse_arguments(argparse)
  81. if args.algorithm == "PPO":
  82. ppo(args)
  83. elif args.algorithm == "DQN":
  84. dqn(args)
  85. if __name__ == '__main__':
  86. main()