You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

137 lines
4.3 KiB

1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
1 year ago
  1. import gymnasium as gym
  2. import minigrid
  3. from ray.tune import register_env
  4. from ray.rllib.algorithms.ppo import PPOConfig
  5. from ray.rllib.algorithms.dqn.dqn import DQNConfig
  6. from ray.tune.logger import pretty_print
  7. from ray.rllib.models import ModelCatalog
  8. from torch_action_mask_model import TorchActionMaskModel
  9. from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
  10. from helpers import parse_arguments, create_log_dir, ShieldingConfig
  11. from shieldhandlers import MiniGridShieldHandler, create_shield_query
  12. from callbacks import MyCallbacks
  13. from ray.tune.logger import TBXLogger
  14. def shielding_env_creater(config):
  15. name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
  16. framestack = config.get("framestack", 4)
  17. args = config.get("args", None)
  18. args.grid_path = F"{args.grid_path}_{config.worker_index}_{args.prism_config}.txt"
  19. args.prism_path = F"{args.prism_path}_{config.worker_index}_{args.prism_config}.prism"
  20. shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config)
  21. env = gym.make(name)
  22. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
  23. shield_query_creator=create_shield_query,
  24. mask_actions=args.shielding != ShieldingConfig.Disabled,
  25. create_shield_at_reset=args.shield_creation_at_reset)
  26. # env = minigrid.wrappers.ImgObsWrapper(env)
  27. # env = ImgObsWrapper(env)
  28. env = OneHotShieldingWrapper(env,
  29. config.vector_index if hasattr(config, "vector_index") else 0,
  30. framestack=framestack
  31. )
  32. return env
  33. def register_minigrid_shielding_env(args):
  34. env_name = "mini-grid-shielding"
  35. register_env(env_name, shielding_env_creater)
  36. ModelCatalog.register_custom_model(
  37. "shielding_model",
  38. TorchActionMaskModel
  39. )
  40. def ppo(args):
  41. register_minigrid_shielding_env(args)
  42. config = (PPOConfig()
  43. .rollouts(num_rollout_workers=args.workers)
  44. .resources(num_gpus=0)
  45. .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
  46. .framework("torch")
  47. .callbacks(MyCallbacks)
  48. .rl_module(_enable_rl_module_api = False)
  49. .debugging(logger_config={
  50. "type": TBXLogger,
  51. "logdir": create_log_dir(args)
  52. })
  53. # .exploration(exploration_config={"exploration_fraction": 0.1})
  54. .training(_enable_learner_api=False ,model={
  55. "custom_model": "shielding_model"
  56. }))
  57. # config.entropy_coeff = 0.05
  58. algo =(
  59. config.build()
  60. )
  61. for i in range(args.evaluations):
  62. result = algo.train()
  63. print(pretty_print(result))
  64. if i % 5 == 0:
  65. checkpoint_dir = algo.save()
  66. print(f"Checkpoint saved in directory {checkpoint_dir}")
  67. algo.save()
  68. def dqn(args):
  69. register_minigrid_shielding_env(args)
  70. config = DQNConfig()
  71. config = config.resources(num_gpus=0)
  72. config = config.rollouts(num_rollout_workers=args.workers)
  73. config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
  74. config = config.framework("torch")
  75. config = config.callbacks(MyCallbacks)
  76. config = config.rl_module(_enable_rl_module_api = False)
  77. config = config.debugging(logger_config={
  78. "type": TBXLogger,
  79. "logdir": create_log_dir(args)
  80. })
  81. config = config.training(hiddens=[], dueling=False, model={
  82. "custom_model": "shielding_model"
  83. })
  84. algo = (
  85. config.build()
  86. )
  87. for i in range(args.evaluations):
  88. result = algo.train()
  89. print(pretty_print(result))
  90. if i % 5 == 0:
  91. print("Saving checkpoint")
  92. checkpoint_dir = algo.save()
  93. print(f"Checkpoint saved in directory {checkpoint_dir}")
  94. def main():
  95. import argparse
  96. args = parse_arguments(argparse)
  97. if args.algorithm == "PPO":
  98. ppo(args)
  99. elif args.algorithm == "DQN":
  100. dqn(args)
  101. if __name__ == '__main__':
  102. main()