The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

100 lines
3.2 KiB

4 months ago
  1. import gymnasium as gym
  2. import minigrid
  3. from ray.tune import register_env
  4. from ray.rllib.algorithms.ppo import PPOConfig
  5. from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
  6. from ray.rllib.models import ModelCatalog
  7. from torch_action_mask_model import TorchActionMaskModel
  8. from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper, shielding_env_creater
  9. from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
  10. from callbacks import CustomCallback
  11. from torch.utils.tensorboard import SummaryWriter
  12. def register_minigrid_shielding_env(args):
  13. env_name = "mini-grid-shielding"
  14. register_env(env_name, shielding_env_creater)
  15. ModelCatalog.register_custom_model(
  16. "shielding_model",
  17. TorchActionMaskModel
  18. )
  19. def ppo(args):
  20. register_minigrid_shielding_env(args)
  21. train_batch_size = 4000
  22. config = (PPOConfig()
  23. .rollouts(num_rollout_workers=args.workers)
  24. .resources(num_gpus=0)
  25. .environment( env="mini-grid-shielding",
  26. env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
  27. .framework("torch")
  28. .callbacks(CustomCallback)
  29. .evaluation(evaluation_config={
  30. "evaluation_interval": 1,
  31. "evaluation_duration": 10,
  32. "evaluation_num_workers":1,
  33. "env": "mini-grid-shielding",
  34. "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
  35. .rl_module(_enable_rl_module_api = False)
  36. .debugging(logger_config={
  37. "type": UnifiedLogger,
  38. "logdir": create_log_dir(args)
  39. })
  40. .training(_enable_learner_api=False ,model={
  41. "custom_model": "shielding_model"
  42. }, train_batch_size=train_batch_size))
  43. algo =(
  44. config.build()
  45. )
  46. iterations = int((args.steps / train_batch_size)) + 1
  47. for i in range(iterations):
  48. algo.train()
  49. if i % 5 == 0:
  50. algo.save()
  51. eval_log_dir = F"{create_log_dir(args)}-eval"
  52. writer = SummaryWriter(log_dir=eval_log_dir)
  53. csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
  54. for i in range(evaluations):
  55. eval_result = algo.evaluate()
  56. print(pretty_print(eval_result))
  57. print(eval_result)
  58. # logger.on_result(eval_result)
  59. csv_logger.on_result(eval_result)
  60. evaluation = eval_result['evaluation']
  61. epsiode_reward_mean = evaluation['episode_reward_mean']
  62. episode_len_mean = evaluation['episode_len_mean']
  63. print(epsiode_reward_mean)
  64. writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
  65. writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
  66. writer.close()
  67. def main():
  68. import argparse
  69. args = parse_arguments(argparse)
  70. ppo(args)
  71. if __name__ == '__main__':
  72. main()