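"""Train a PPO agent on a MiniGrid environment with optional shielding.

A shield is synthesised from a PRISM model of the grid (see
MiniGridShieldHandler) and enforced by masking unsafe actions during
training and/or evaluation, depending on the chosen ShieldingConfig.
"""
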
import gymnasium as gym
import minigrid
import ray
from ray.tune import register_env
from ray.tune.experiment.trial import Trial
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
from ray.rllib.algorithms.algorithm import Algorithm
from ray.air import session
from torch.utils.tensorboard import SummaryWriter

from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
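

# RLlib calls the registered env creator once per rollout worker, passing an
# EnvContext that exposes worker_index (and vector_index). This is used below
# to give every worker its own grid/PRISM files.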
def shielding_env_creator(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    # Per-worker file names so parallel workers do not overwrite each other.
    args.grid_path = f"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = f"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)

    # Builds the shield from a PRISM translation of the grid.
    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
                                           grid_to_prism_path=args.grid_to_prism_binary_path,
                                           prism_path=args.prism_path,
                                           formula=args.formula,
                                           shield_value=args.shield_value,
                                           prism_config=args.prism_config)

    env = gym.make(name, randomize_start=True)
    # The shield masks unsafe actions unless shielding is disabled entirely.
    env = MiniGridShieldingWrapper(env,
                                   shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query,
                                   mask_actions=shielding != ShieldingConfig.Disabled)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)
    return env


def register_minigrid_shielding_env(args):
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creator)
    ModelCatalog.register_custom_model("shielding_model", TorchActionMaskModel)


def trial_name_creator(trial: Trial):
    # A constant trial name keeps Tune's output directory stable across runs.
    return "trial"
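

# Trains PPO via Tune and prints the key metrics of the best result.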
def ppo(args):
    register_minigrid_shielding_env(args)
    logdir = args.log_dir

    config = (PPOConfig()
              .rollouts(num_rollout_workers=args.workers)
              .resources(num_gpus=0)
              .environment(env="mini-grid-shielding",
                           env_config={"name": args.env,
                                       "args": args,
                                       "shielding": args.shielding is ShieldingConfig.Full
                                                    or args.shielding is ShieldingConfig.Training})
              .framework("torch")
              .callbacks(MyCallbacks)
              .evaluation(evaluation_config={
                  "evaluation_interval": 1,
                  "evaluation_duration": 10,
                  "evaluation_num_workers": 1,
                  "env": "mini-grid-shielding",
                  "env_config": {"name": args.env,
                                 "args": args,
                                 "shielding": args.shielding is ShieldingConfig.Full
                                              or args.shielding is ShieldingConfig.Evaluation}})
              .rl_module(_enable_rl_module_api=False)
              .debugging(logger_config={
                  "type": UnifiedLogger,
                  "logdir": logdir})
              .training(_enable_learner_api=False,
                        model={"custom_model": "shielding_model"}))
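
    # custom_model relies on RLlib's legacy ModelV2 stack, hence the RLModule
    # and Learner APIs are explicitly disabled above.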

    tuner = tune.Tuner(
        "PPO",
        tune_config=tune.TuneConfig(
            metric="episode_reward_mean",
            mode="max",
            num_samples=1,
            trial_name_creator=trial_name_creator,
        ),
        run_config=air.RunConfig(
            stop={"episode_reward_mean": 94,
                  "timesteps_total": args.steps},
            checkpoint_config=air.CheckpointConfig(
                checkpoint_at_end=True,
                num_to_keep=1,
                checkpoint_score_attribute="episode_reward_mean",
            ),
            storage_path=f"{logdir}",
            name=test_name(args),
        ),
        param_space=config,
    )
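
    # fit() runs until the mean episode reward reaches 94 or args.steps
    # timesteps have elapsed, whichever stop criterion is met first.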
    results = tuner.fit()
    best_result = results.get_best_result()

    import pprint
    metrics_to_print = [
        "episode_reward_mean",
        "episode_reward_max",
        "episode_reward_min",
        "episode_len_mean",
    ]
    pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
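
    # The disabled block below would reload the best checkpoint, re-evaluate
    # it, and log the results to TensorBoard and CSV.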
    # algo = Algorithm.from_checkpoint(best_result.checkpoint)
    # eval_log_dir = f"{logdir}-eval"
    # writer = SummaryWriter(log_dir=eval_log_dir)
    # csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
    # for i in range(args.evaluations):
    #     eval_result = algo.evaluate()
    #     print(pretty_print(eval_result))
    #     # logger.on_result(eval_result)
    #     csv_logger.on_result(eval_result)
    #     evaluation = eval_result['evaluation']
    #     episode_reward_mean = evaluation['episode_reward_mean']
    #     episode_len_mean = evaluation['episode_len_mean']
    #     print(episode_reward_mean)
    #     writer.add_scalar("evaluation/episode_reward_mean", episode_reward_mean, i)
    #     writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)


def main():
    ray.init(num_cpus=3)
    import argparse
    args = parse_arguments(argparse)
    ppo(args)
    ray.shutdown()


if __name__ == '__main__':
    main()