You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

105 lines
3.0 KiB

  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # ## Example usage of Tempestpy
  4. # In[1]:
  5. from sb3_contrib import MaskablePPO
  6. from sb3_contrib.common.wrappers import ActionMasker
  7. from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
  8. import gymnasium as gym
  9. from minigrid.core.actions import Actions
  10. from minigrid.core.constants import TILE_PIXELS
  11. from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
  12. import tempfile, datetime, shutil
  13. import time
  14. import os
  15. from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image
  16. from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
  17. import os, sys
  18. from copy import deepcopy
  19. from PIL import Image
  20. # In[3]:
  # Value of the M2P_BINARY environment variable (None when it is unset).
  # It is handed to MiniGridShieldHandler in main(); presumably the path of
  # the grid-to-PRISM converter executable -- confirm against utils.
  GRID_TO_PRISM_BINARY = os.environ.get("M2P_BINARY")
  def mask_fn(env: gym.Env):
      """Return the action mask produced by *env*.

      Thin adapter that forwards to ``env.create_action_mask()``; main()
      passes it to ``ActionMasker`` when training with a shield.
      """
      action_mask = env.create_action_mask()
      return action_mask
  def nomask_fn(env: gym.Env):
      """Return a mask that allows every action, ignoring *env*.

      The result is a fresh list of seven ``1.0`` entries (presumably one
      per MiniGrid action -- confirm against the action space in use).
      main() passes this function to ``ActionMasker`` when shielding is
      disabled.
      """
      return [1.0 for _ in range(7)]
  def main():
      """Build a (optionally shielded) MiniGrid environment and a MaskablePPO model.

      Steps, in order:
        1. Create the MiniGrid environment and wrap it for image observations.
        2. Show the first rendered frame in an image viewer.
        3. If shielding is needed, create one shield handler per shield value,
           wrap the environment with each, and produce an overlay image of
           every shield.
        4. Attach the action masker that matches the shielding configuration.
        5. Create the MaskablePPO model and start learning.

      Raises:
          AssertionError: if the shielding configuration is neither
              ``ShieldingConfig.Training`` nor ``ShieldingConfig.Disabled``;
              also raised by the halt placed before ``model.learn`` when
              assertions are enabled.
      """
      # Alternative environment:
      # env_id = "MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0"
      env_id = "MiniGrid-WindyCity2-v0"
      formula = "Pmax=? [G ! AgentIsOnLava]"
      value_for_training = 0.99
      shield_comparison = "absolute"
      shielding = ShieldingConfig.Training
      # Shield values for which a handler and an overlay image are produced.
      shield_values = (0.9, 0.95, 0.99, 1.0)

      logger = Logger("/tmp", output_formats=[HumanOutputFormat(sys.stdout)])

      env = gym.make(env_id, render_mode="rgb_array")
      # Separate RGB view at TILE_PIXELS per tile, used only for the overlays.
      image_env = RGBImgObsWrapper(env, TILE_PIXELS)
      env = RGBImgObsWrapper(env, 8)
      env = ImgObsWrapper(env)
      env = MiniWrapper(env)
      env.reset()
      Image.fromarray(env.render()).show()

      shield_handlers = {}
      if shield_needed(shielding):
          # First pass: one handler per shield value, each wrapped around env.
          for value in shield_values:
              shield_handler = MiniGridShieldHandler(
                  GRID_TO_PRISM_BINARY,
                  "grid.txt",
                  "grid.prism",
                  formula,
                  shield_value=value,
                  shield_comparison=shield_comparison,
                  nocleanup=True,
                  prism_file=None,
              )
              env = MiniGridSbShieldingWrapper(
                  env,
                  shield_handler=shield_handler,
                  create_shield_at_reset=False,
              )
              shield_handlers[value] = shield_handler

          # Second pass (kept separate so every wrapper exists before any
          # shield is created, as in the original statement order).
          for value in shield_values:
              create_shield_overlay_image(image_env, shield_handlers[value].create_shield())
              print(f"The shield for shield_value = {value}")

      if shielding == ShieldingConfig.Training:
          training_handler = shield_handlers[value_for_training]
          env = MiniGridSbShieldingWrapper(
              env,
              shield_handler=training_handler,
              create_shield_at_reset=False,
          )
          env = ActionMasker(env, mask_fn)
          print("Training with shield:")
          create_shield_overlay_image(image_env, training_handler.create_shield())
      elif shielding == ShieldingConfig.Disabled:
          env = ActionMasker(env, nomask_fn)
      else:
          # Explicit raise instead of `assert False`: assert statements are
          # removed under `python -O`, which would let an unsupported
          # configuration continue with an unmasked environment.
          raise AssertionError(f"unsupported shielding configuration: {shielding!r}")

      model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
      model.set_logger(logger)
      steps = 20_000
      # NOTE(review): this assertion stops the script before training whenever
      # assertions are enabled, so the learn() call below is normally
      # unreachable. Presumably a leftover stop from the notebook this script
      # was exported from -- confirm before removing.
      assert False
      model.learn(steps, callback=[InfoCallback()])
  # Script entry point: announce the run, then hand over to main().
  if __name__ == "__main__":
      print("Starting the training")
      main()
  68. # In[ ]: