"""Train a shielded PPO or DQN agent on MiniGrid environments with RLlib and Tune."""

import gymnasium as gym
import minigrid

from ray import tune, air
from ray.tune import register_env
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger

from torch.utils.tensorboard import SummaryWriter

from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
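
# Environment factory passed to RLlib: each rollout worker writes its own copy of the
# grid/PRISM files (suffixed with the worker index), builds a shield from them, and
# wraps the environment for shielding plus one-hot observations with frame stacking.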
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    args.grid_path = f"{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = f"{args.prism_path}_{config.worker_index}.prism"

    shield_creator = MiniGridShieldHandler(
        args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)

    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query)
    # env = minigrid.wrappers.ImgObsWrapper(env)
    # env = ImgObsWrapper(env)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env
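
# Register the environment factory with Tune and the custom action-masking model with
# RLlib's ModelCatalog so both can be referenced by name in the configs below.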
def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel)
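
# PPO configuration. The legacy ModelV2 stack is kept (_enable_rl_module_api and
# _enable_learner_api set to False) so the custom action-masking model can be used.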
def ppo(args):
    register_minigrid_shielding_env(args)

    config = (PPOConfig()
              .rollouts(num_rollout_workers=args.workers)
              .resources(num_gpus=0)
              .environment(env="mini-grid-shielding",
                           env_config={"name": args.env,
                                       "args": args,
                                       "shielding": args.shielding is ShieldingConfig.Full
                                                    or args.shielding is ShieldingConfig.Training})
              .framework("torch")
              .callbacks(MyCallbacks)
              .rl_module(_enable_rl_module_api=False)
              .debugging(logger_config={
                  "type": TBXLogger,
                  "logdir": create_log_dir(args)
              })
              .training(_enable_learner_api=False, model={
                  "custom_model": "shielding_model"
              }))

    return config
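
# DQN configuration. Dueling and the post-model hiddens are disabled, as in RLlib's
# action-masking examples, because those extra heads do not combine well with the
# custom action-masking model.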
def dqn(args):
    register_minigrid_shielding_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding",
                                env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": TBXLogger,
        "logdir": create_log_dir(args)
    })
    config = config.training(hiddens=[], dueling=False, model={
        "custom_model": "shielding_model"
    })

    return config
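
# Entry point: parse the command-line arguments, build the selected algorithm's config,
# and run a single Tune trial that stops once the mean episode reward reaches 94 or
# 12000 environment timesteps have been collected.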
def main():
    import argparse
    args = parse_arguments(argparse)

    if args.algorithm == "PPO":
        config = ppo(args)
    elif args.algorithm == "DQN":
        config = dqn(args)
    else:
        raise ValueError(f"Unsupported algorithm: {args.algorithm}")

    logdir = create_log_dir(args)
    tuner = tune.Tuner(
        args.algorithm,
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
            num_samples=1,
        ),
        run_config=air.RunConfig(
            stop={"episode_reward_mean": 94,
                  "timesteps_total": 12000},
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
            storage_path=f"{logdir}",
        ),
        param_space=config,
    )

    tuner.fit()
if __name__ == '__main__':
    main()
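
# Example invocation (the exact flag names are defined in helpers.parse_arguments and
# may differ; the ones below are illustrative only):
#   python <this script> --algorithm PPO --env MiniGrid-LavaCrossingS9N1-v0 --workers 2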