You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

97 lines
4.4 KiB

11 months ago
11 months ago
  1. from sb3_contrib import MaskablePPO
  2. from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback
  3. from sb3_contrib.common.wrappers import ActionMasker
  4. from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
  5. import gymnasium as gym
  6. from minigrid.core.actions import Actions
  7. from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
  8. import time
  9. from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation
  10. from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
  11. from stable_baselines3.common.callbacks import EvalCallback
  12. import os, sys
  13. from copy import deepcopy
  14. GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
  15. def mask_fn(env: gym.Env):
  16. return env.create_action_mask()
  17. def nomask_fn(env: gym.Env):
  18. return [1.0] * 7
  19. def main():
  20. args = parse_sb3_arguments()
  21. formula = args.formula
  22. shield_value = args.shield_value
  23. shield_comparison = args.shield_comparison
  24. log_dir = create_log_dir(args)
  25. #new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
  26. new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)])
  27. if shield_needed(args.shielding):
  28. shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup)
  29. env = gym.make(args.env, render_mode="rgb_array")
  30. env = RGBImgObsWrapper(env)
  31. env = ImgObsWrapper(env)
  32. env = MiniWrapper(env)
  33. eval_env = deepcopy(env)
  34. eval_env.disable_random_start()
  35. if args.shielding == ShieldingConfig.Full:
  36. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
  37. env = ActionMasker(env, mask_fn)
  38. eval_env = MiniGridSbShieldingWrapper(eval_env, shield_handler=shield_handler, create_shield_at_reset=False)
  39. eval_env = ActionMasker(eval_env, mask_fn)
  40. elif args.shielding == ShieldingConfig.Training:
  41. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
  42. env = ActionMasker(env, mask_fn)
  43. eval_env = ActionMasker(eval_env, nomask_fn)
  44. elif args.shielding == ShieldingConfig.Evaluation:
  45. env = ActionMasker(env, nomask_fn)
  46. eval_env = MiniGridSbShieldingWrapper(eval_env, shield_handler=shield_handler, create_shield_at_reset=False)
  47. eval_env = ActionMasker(eval_env, mask_fn)
  48. elif args.shielding == ShieldingConfig.Disabled:
  49. env = ActionMasker(env, nomask_fn)
  50. eval_env = ActionMasker(eval_env, nomask_fn)
  51. else:
  52. assert(False) # TODO Do something proper
  53. #model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
  54. #model.set_logger(new_logger)
  55. #steps = args.steps
  56. ## Evaluation
  57. #eval_freq=max(500, int(args.steps/30))
  58. #n_eval_episodes=5
  59. #render_freq = eval_freq
  60. #if shielded_evaluation(args.shielding):
  61. # from sb3_contrib.common.maskable.evaluation import evaluate_policy
  62. # evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir,
  63. # log_path=log_dir, eval_freq=eval_freq,
  64. # deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
  65. # imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
  66. #else:
  67. # from stable_baselines3.common.evaluation import evaluate_policy
  68. # evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir,
  69. # log_path=log_dir, eval_freq=eval_freq,
  70. # deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
  71. # imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
  72. #model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback])
  73. #model.save(f"{log_dir}/{expname(args)}")
  74. if __name__ == '__main__':
  75. main()