You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.8 KiB

  1. from sb3_contrib import MaskablePPO
  2. from sb3_contrib.common.maskable.evaluation import evaluate_policy
  3. from sb3_contrib.common.wrappers import ActionMasker
  4. from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat
  5. import gymnasium as gym
  6. from minigrid.core.actions import Actions
  7. from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
  8. import time
  9. from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname
  10. from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
  11. from stable_baselines3.common.callbacks import EvalCallback
  12. import os
  13. GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
  14. def mask_fn(env: gym.Env):
  15. return env.create_action_mask()
  16. def nomask_fn(env: gym.Env):
  17. return [1.0] * 7
  18. def main():
  19. args = parse_sb3_arguments()
  20. formula = args.formula
  21. shield_value = args.shield_value
  22. shield_comparison = args.shield_comparison
  23. log_dir = create_log_dir(args)
  24. new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
  25. env = gym.make(args.env, render_mode="rgb_array")
  26. env = RGBImgObsWrapper(env)
  27. env = ImgObsWrapper(env)
  28. env = MiniWrapper(env)
  29. if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training:
  30. shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup)
  31. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
  32. env = ActionMasker(env, mask_fn)
  33. else:
  34. env = ActionMasker(env, nomask_fn)
  35. model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
  36. model.set_logger(new_logger)
  37. evalCallback = EvalCallback(env, best_model_save_path=log_dir,
  38. log_path=log_dir, eval_freq=max(500, int(args.steps/30)),
  39. deterministic=True, render=False)
  40. steps = args.steps
  41. model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()])
  42. #vec_env = model.get_env()
  43. #obs = vec_env.reset()
  44. #terminated = truncated = False
  45. #while not terminated and not truncated:
  46. # action_masks = None
  47. # action, _states = model.predict(obs, action_masks=action_masks)
  48. # print(action)
  49. # obs, reward, terminated, truncated, info = env.step(action)
  50. # # action, _states = model.predict(obs, deterministic=True)
  51. # # obs, rewards, dones, info = vec_env.step(action)
  52. # vec_env.render("human")
  53. # time.sleep(0.2)
  54. if __name__ == '__main__':
  55. main()