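"""Train a PPO agent with Ray RLlib/Tune on a shielded MiniGrid environment.

The environment factory wraps the MiniGrid task in a MiniGridShieldingWrapper,
which masks actions according to a shield computed by MiniGridShieldHandler from
a PRISM model of the grid, and in a OneHotShieldingWrapper that produces one-hot,
frame-stacked observations.
"""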
import gymnasium as gym
import minigrid
import ray
from ray.tune import register_env
from ray.tune.experiment.trial import Trial
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import make_multi_callbacks
from ray.air import session
from torch.utils.tensorboard import SummaryWriter

from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks

def shielding_env_creater(config):
    """Create a (possibly shielded) MiniGrid environment for one rollout worker."""
    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    # Give every worker its own grid and PRISM file so parallel workers do not clash.
    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
    shielding = config.get("shielding", False)

    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
                                           grid_to_prism_path=args.grid_to_prism_binary_path,
                                           prism_path=args.prism_path,
                                           formula=args.formula,
                                           shield_value=args.shield_value,
                                           prism_config=args.prism_config,
                                           shield_comparision=args.shield_comparision)

    prob_forward = args.prob_forward
    prob_direct = args.prob_direct
    prob_next = args.prob_next

    env = gym.make(name,
                   randomize_start=True,
                   probability_forward=prob_forward,
                   probability_direct_neighbour=prob_direct,
                   probability_next_neighbour=prob_next)
    # Mask unsafe actions unless shielding is disabled, then one-hot encode observations.
    env = MiniGridShieldingWrapper(env,
                                   shield_creator=shield_creator,
                                   shield_query_creator=create_shield_query,
                                   mask_actions=shielding != ShieldingConfig.Disabled)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)

    return env

def register_minigrid_shielding_env(args):
    # Register the shielded environment and the action-masking model with RLlib/Tune.
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)
    ModelCatalog.register_custom_model("shielding_model", TorchActionMaskModel)


def trial_name_creator(trial: Trial):
    # Give every trial the same short name.
    return "trial"

def ppo(args):
    register_minigrid_shielding_env(args)
    logdir = args.log_dir

    config = (PPOConfig()
              .rollouts(num_rollout_workers=args.workers)
              .resources(num_gpus=args.num_gpus)
              .environment(env="mini-grid-shielding",
                           env_config={"name": args.env,
                                       "args": args,
                                       "shielding": args.shielding is ShieldingConfig.Full
                                                    or args.shielding is ShieldingConfig.Training})
              .framework("torch")
              .callbacks(MyCallbacks)
              .evaluation(evaluation_config={
                  "evaluation_interval": 1,
                  "evaluation_duration": 10,
                  "evaluation_num_workers": 1,
                  "env": "mini-grid-shielding",
                  "env_config": {"name": args.env,
                                 "args": args,
                                 "shielding": args.shielding is ShieldingConfig.Full
                                              or args.shielding is ShieldingConfig.Evaluation}})
              # Disable the new RLModule/Learner APIs so the custom ModelV2 action-mask model is used.
              .rl_module(_enable_rl_module_api=False)
              .debugging(logger_config={
                  "type": UnifiedLogger,
                  "logdir": logdir})
              .training(_enable_learner_api=False,
                        model={"custom_model": "shielding_model"}))

    tuner = tune.Tuner("PPO",
                       tune_config=tune.TuneConfig(
                           metric="episode_reward_mean",
                           mode="max",
                           num_samples=1,
                           trial_name_creator=trial_name_creator,
                       ),
                       run_config=air.RunConfig(
                           # Stop once the mean return reaches 94 or the step budget is exhausted.
                           stop={"episode_reward_mean": 94,
                                 "timesteps_total": args.steps},
                           checkpoint_config=air.CheckpointConfig(
                               checkpoint_at_end=True,
                               num_to_keep=1,
                               checkpoint_score_attribute="episode_reward_mean",
                           ),
                           storage_path=F"{logdir}",
                           name=test_name(args),
                       ),
                       param_space=config)

    results = tuner.fit()
    best_result = results.get_best_result()

    import pprint
    metrics_to_print = [
        "episode_reward_mean",
        "episode_reward_max",
        "episode_reward_min",
        "episode_len_mean",
    ]
    pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
    # Optional: reload the best checkpoint and run additional evaluations,
    # logging the results to TensorBoard and CSV.
    # algo = Algorithm.from_checkpoint(best_result.checkpoint)
    # eval_log_dir = F"{logdir}-eval"
    # writer = SummaryWriter(log_dir=eval_log_dir)
    # csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
    # for i in range(args.evaluations):
    #     eval_result = algo.evaluate()
    #     print(pretty_print(eval_result))
    #     print(eval_result)
    #     # logger.on_result(eval_result)
    #     csv_logger.on_result(eval_result)
    #     evaluation = eval_result['evaluation']
    #     episode_reward_mean = evaluation['episode_reward_mean']
    #     episode_len_mean = evaluation['episode_len_mean']
    #     print(episode_reward_mean)
    #     writer.add_scalar("evaluation/episode_reward_mean", episode_reward_mean, i)
    #     writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)

def main():
    ray.init(num_cpus=3)

    import argparse
    args = parse_arguments(argparse)

    ppo(args)

    ray.shutdown()


if __name__ == '__main__':
    main()