You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

127 lines
4.0 KiB

  1. import gymnasium as gym
  2. import minigrid
  3. from ray.tune import register_env
  4. from ray.rllib.algorithms.ppo import PPOConfig
  5. from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
  6. from ray.rllib.models import ModelCatalog
  7. from torch_action_mask_model import TorchActionMaskModel
  8. from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
  9. from helpers import parse_arguments, create_log_dir, ShieldingConfig
  10. from shieldhandlers import MiniGridShieldHandler, create_shield_query
  11. from callbacks import MyCallbacks
  12. from torch.utils.tensorboard import SummaryWriter
  13. def shielding_env_creater(config):
  14. name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
  15. framestack = config.get("framestack", 4)
  16. args = config.get("args", None)
  17. args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
  18. args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
  19. shielding = config.get("shielding", False)
  20. shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
  21. env = gym.make(name)
  22. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
  23. env = OneHotShieldingWrapper(env,
  24. config.vector_index if hasattr(config, "vector_index") else 0,
  25. framestack=framestack
  26. )
  27. return env
  28. def register_minigrid_shielding_env(args):
  29. env_name = "mini-grid-shielding"
  30. register_env(env_name, shielding_env_creater)
  31. ModelCatalog.register_custom_model(
  32. "shielding_model",
  33. TorchActionMaskModel
  34. )
  35. def ppo(args):
  36. register_minigrid_shielding_env(args)
  37. config = (PPOConfig()
  38. .rollouts(num_rollout_workers=args.workers)
  39. .resources(num_gpus=0)
  40. .environment( env="mini-grid-shielding",
  41. env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
  42. .framework("torch")
  43. .callbacks(MyCallbacks)
  44. .evaluation(evaluation_config={
  45. "evaluation_interval": 1,
  46. "evaluation_duration": 10,
  47. "evaluation_num_workers":1,
  48. "env": "mini-grid-shielding",
  49. "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
  50. .rl_module(_enable_rl_module_api = False)
  51. .debugging(logger_config={
  52. "type": UnifiedLogger,
  53. "logdir": create_log_dir(args)
  54. })
  55. .training(_enable_learner_api=False ,model={
  56. "custom_model": "shielding_model"
  57. }))
  58. algo =(
  59. config.build()
  60. )
  61. evaluations = args.evaluations
  62. for i in range(evaluations):
  63. algo.train()
  64. if i % 5 == 0:
  65. algo.save()
  66. eval_log_dir = F"{create_log_dir(args)}-eval"
  67. writer = SummaryWriter(log_dir=eval_log_dir)
  68. csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
  69. for i in range(evaluations):
  70. eval_result = algo.evaluate()
  71. print(pretty_print(eval_result))
  72. print(eval_result)
  73. # logger.on_result(eval_result)
  74. csv_logger.on_result(eval_result)
  75. evaluation = eval_result['evaluation']
  76. epsiode_reward_mean = evaluation['episode_reward_mean']
  77. episode_len_mean = evaluation['episode_len_mean']
  78. print(epsiode_reward_mean)
  79. writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
  80. writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
  81. writer.close()
  82. def main():
  83. import argparse
  84. args = parse_arguments(argparse)
  85. ppo(args)
  86. if __name__ == '__main__':
  87. main()