You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 lines
4.1 KiB

11 months ago
  1. import gymnasium as gym
  2. import minigrid
  3. import ray
  4. from ray.tune import register_env
  5. from ray.tune.experiment.trial import Trial
  6. from ray import tune, air
  7. from ray.rllib.algorithms.ppo import PPOConfig
  8. from ray.tune.logger import UnifiedLogger
  9. from ray.rllib.models import ModelCatalog
  10. from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
  11. from ray.rllib.algorithms.algorithm import Algorithm
  12. from ray.rllib.algorithms.callbacks import make_multi_callbacks
  13. from ray.air import session
  14. from torch_action_mask_model import TorchActionMaskModel
  15. from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
  16. from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig, test_name
  17. from torch.utils.tensorboard import SummaryWriter
  18. from callbacks import CustomCallback
  19. def register_minigrid_shielding_env(args):
  20. env_name = "mini-grid-shielding"
  21. register_env(env_name, shielding_env_creater)
  22. ModelCatalog.register_custom_model(
  23. "shielding_model",
  24. TorchActionMaskModel
  25. )
  26. def trial_name_creator(trial : Trial):
  27. return "trial"
  28. def ppo(args):
  29. register_minigrid_shielding_env(args)
  30. logdir = args.log_dir
  31. config = (PPOConfig()
  32. .rollouts(num_rollout_workers=args.workers)
  33. .resources(num_gpus=args.num_gpus)
  34. .environment( env="mini-grid-shielding",
  35. env_config={"name": args.env,
  36. "args": args,
  37. "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training,
  38. },)
  39. .framework("torch")
  40. .callbacks(CustomCallback)
  41. .evaluation(evaluation_config={
  42. "evaluation_interval": 1,
  43. "evaluation_duration": 10,
  44. "evaluation_num_workers":1,
  45. "env": "mini-grid-shielding",
  46. "env_config": {"name": args.env,
  47. "args": args,
  48. "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
  49. .rl_module(_enable_rl_module_api = False)
  50. .debugging(logger_config={
  51. "type": UnifiedLogger,
  52. "logdir": logdir
  53. })
  54. .training(_enable_learner_api=False ,model={
  55. "custom_model": "shielding_model"
  56. }))
  57. tuner = tune.Tuner("PPO",
  58. tune_config=tune.TuneConfig(
  59. metric="episode_reward_mean",
  60. mode="max",
  61. num_samples=1,
  62. trial_name_creator=trial_name_creator,
  63. ),
  64. run_config=air.RunConfig(
  65. stop = {"episode_reward_mean": 1,
  66. "timesteps_total": args.steps,},
  67. checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True,
  68. num_to_keep=1,
  69. checkpoint_score_attribute="episode_reward_mean",
  70. ),
  71. storage_path=F"{logdir}",
  72. name=test_name(args),
  73. ),
  74. param_space=config,)
  75. results = tuner.fit()
  76. best_result = results.get_best_result()
  77. import pprint
  78. metrics_to_print = [
  79. "episode_reward_mean",
  80. "episode_reward_max",
  81. "episode_reward_min",
  82. "episode_len_mean",
  83. ]
  84. pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
  85. def main():
  86. ray.init(num_cpus=3)
  87. import argparse
  88. args = parse_arguments(argparse)
  89. ppo(args)
  90. ray.shutdown()
  91. if __name__ == '__main__':
  92. main()