import gymnasium as gym
import minigrid  # importing minigrid registers the MiniGrid environments with gymnasium

from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks

from torch.utils.tensorboard import SummaryWriter
def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    # Give each rollout worker its own grid/PRISM file so workers do not clash.
    args.grid_path = f"{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = f"{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)

    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path,
                                           args.prism_path, args.formula)

    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query,
                                   mask_actions=shielding)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env
def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel)
def ppo(args):
    register_minigrid_shielding_env(args)
    train_batch_size = 4000

    config = (PPOConfig()
        .rollouts(num_rollout_workers=args.workers)
        .resources(num_gpus=0)
        .environment(env="mini-grid-shielding",
                     env_config={"name": args.env,
                                 "args": args,
                                 "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(MyCallbacks)
        .evaluation(evaluation_interval=1,
                    evaluation_duration=10,
                    evaluation_num_workers=1,
                    evaluation_config={
                        "env": "mini-grid-shielding",
                        "env_config": {"name": args.env,
                                       "args": args,
                                       "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
        .rl_module(_enable_rl_module_api=False)
        .debugging(logger_config={
            "type": UnifiedLogger,
            "logdir": create_log_dir(args)})
        .training(_enable_learner_api=False,
                  model={"custom_model": "shielding_model"},
                  train_batch_size=train_batch_size))

    algo = config.build()
    iterations = int(args.steps / train_batch_size) + 1
    for i in range(iterations):
        algo.train()

        if i % 5 == 0:
            algo.save()
    eval_log_dir = f"{create_log_dir(args)}-eval"
    writer = SummaryWriter(log_dir=eval_log_dir)
    csv_logger = CSVLogger(config=config, logdir=eval_log_dir)

    evaluations = 10  # number of evaluation rounds; the original left this undefined, value assumed
    for i in range(evaluations):
        eval_result = algo.evaluate()
        print(pretty_print(eval_result))
        print(eval_result)

        # logger.on_result(eval_result)
        csv_logger.on_result(eval_result)

        evaluation = eval_result['evaluation']
        episode_reward_mean = evaluation['episode_reward_mean']
        episode_len_mean = evaluation['episode_len_mean']
        print(episode_reward_mean)

        writer.add_scalar("evaluation/episode_reward_mean", episode_reward_mean, i)
        writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)

    writer.close()
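
# Illustrative sketch, not part of the original script: builds the shielded MiniGrid
# environment directly (mirroring shielding_env_creater) so the wrappers can be
# exercised without starting RLlib workers. The `args` object passed in is assumed to
# carry the same attributes used above (grid_path, grid_to_prism_binary_path,
# prism_path, formula); the default environment name and framestack are assumptions.
def smoke_test_env(args, name="MiniGrid-LavaCrossingS9N1-v0", framestack=4):
    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path,
                                           args.prism_path, args.formula)
    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query,
                                   mask_actions=True)
    env = OneHotShieldingWrapper(env, 0, framestack=framestack)

    # One reset and one random step, just to confirm the wrapped env is usable.
    obs, info = env.reset()
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()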
def main():
    import argparse
    args = parse_arguments(argparse)

    ppo(args)


if __name__ == '__main__':
    main()