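"""Train a shielded MiniGrid agent with RLlib (PPO or DQN).

The script registers a MiniGrid environment whose actions are filtered by a
shield built from a PRISM model (MiniGridShieldHandler), attaches a custom
action-masking model (TorchActionMaskModel), and trains it with either PPO or
DQN. Command-line arguments are parsed by helpers.parse_arguments.
"""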
import gymnasium as gym
import minigrid  # importing minigrid registers the MiniGrid environments with gymnasium

from ray.tune import register_env
from ray.tune.logger import pretty_print, TBXLogger
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks


def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    # Make the grid/PRISM file paths unique per rollout worker.
    args.grid_path = f"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
    args.prism_path = f"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"

    prob_forward = args.prob_forward
    prob_direct = args.prob_direct
    prob_next = args.prob_next

    # Build the shield from the grid file, the grid-to-PRISM binary and the given formula.
    shield_creator = MiniGridShieldHandler(args.grid_path,
                                           args.grid_to_prism_binary_path,
                                           args.prism_path,
                                           args.formula,
                                           args.shield_value,
                                           args.prism_config,
                                           shield_comparision=args.shield_comparision)

    env = gym.make(name,
                   randomize_start=True,
                   probability_forward=prob_forward,
                   probability_direct_neighbour=prob_direct,
                   probability_next_neighbour=prob_next)
    env = MiniGridShieldingWrapper(env,
                                   shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query,
                                   mask_actions=args.shielding != ShieldingConfig.Disabled,
                                   create_shield_at_reset=args.shield_creation_at_reset)
    # env = minigrid.wrappers.ImgObsWrapper(env)
    # env = ImgObsWrapper(env)
    # One-hot encode the observations (with frame stacking) for the policy network.
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)

    return env


def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)
    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )


def ppo(args):
    train_batch_size = 4000
    register_minigrid_shielding_env(args)

    config = (PPOConfig()
              .rollouts(num_rollout_workers=args.workers)
              .resources(num_gpus=0)
              .environment(env="mini-grid-shielding",
                           env_config={"name": args.env,
                                       "args": args,
                                       "shielding": args.shielding is ShieldingConfig.Full
                                                    or args.shielding is ShieldingConfig.Training})
              .framework("torch")
              .callbacks(MyCallbacks)
              # Stay on the old ModelV2 API so the custom action-mask model is used.
              .rl_module(_enable_rl_module_api=False)
              .debugging(logger_config={
                  "type": TBXLogger,
                  "logdir": create_log_dir(args)
              })
              # .exploration(exploration_config={"exploration_fraction": 0.1})
              .training(_enable_learner_api=False,
                        model={"custom_model": "shielding_model"},
                        train_batch_size=train_batch_size))
    # config.entropy_coeff = 0.05
    algo = config.build()

    iterations = int(args.steps / train_batch_size) + 1
    for i in range(iterations):
        result = algo.train()
        print(pretty_print(result))

        if i % 5 == 0:
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")

    algo.save()


def dqn(args):
    train_batch_size = 4000
    register_minigrid_shielding_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding",
                                env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": TBXLogger,
        "logdir": create_log_dir(args)
    })
    # Plain Q-network (no dueling head, no extra hiddens) so the custom
    # action-mask model's outputs are used directly.
    config = config.training(hiddens=[],
                             dueling=False,
                             train_batch_size=train_batch_size,
                             model={"custom_model": "shielding_model"})

    algo = config.build()

    iterations = int(args.steps / train_batch_size) + 1
    for i in range(iterations):
        result = algo.train()
        print(pretty_print(result))

        if i % 5 == 0:
            print("Saving checkpoint")
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")


def main():
    import argparse
    args = parse_arguments(argparse)

    if args.algorithm == "PPO":
        ppo(args)
    elif args.algorithm == "DQN":
        dqn(args)


if __name__ == '__main__':
    main()
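
# Example invocation (the flag names below are illustrative; the actual CLI
# flags are defined in helpers.parse_arguments):
#
#   python <this_script>.py --algorithm PPO --env MiniGrid-LavaCrossingS9N1-v0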