You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.0 KiB

  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # ## Example usage of Tempestpy
  4. # In[1]:
  5. from sb3_contrib import MaskablePPO
  6. from sb3_contrib.common.wrappers import ActionMasker
  7. from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
  8. import gymnasium as gym
  9. from minigrid.core.actions import Actions
  10. from minigrid.core.constants import TILE_PIXELS
  11. from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
  12. import tempfile, datetime, shutil
  13. import time
  14. import os
  15. from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image
  16. from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
  17. import os, sys
  18. from copy import deepcopy
  19. from PIL import Image
  20. # In[ ]:
  21. GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
  22. def mask_fn(env: gym.Env):
  23. return env.create_action_mask()
  24. def nomask_fn(env: gym.Env):
  25. return [1.0] * 7
  26. def main():
  27. env = "MiniGrid-LavaGapS6-v0"
  28. # TODO Change the safety specification
  29. formula = "Pmax=? [G !AgentIsOnLava]"
  30. value_for_training = 1.0
  31. shield_comparison = "absolute"
  32. shielding = ShieldingConfig.Training
  33. logger = Logger("/tmp", output_formats=[HumanOutputFormat(sys.stdout)])
  34. env = gym.make(env, render_mode="rgb_array")
  35. image_env = RGBImgObsWrapper(env, TILE_PIXELS)
  36. env = RGBImgObsWrapper(env, 8)
  37. env = ImgObsWrapper(env)
  38. env = MiniWrapper(env)
  39. env.reset()
  40. Image.fromarray(env.render()).show()
  41. input("")
  42. shield_handlers = dict()
  43. if shield_needed(shielding):
  44. for value in [0.0, 1.0]:
  45. shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, "grid.txt", "grid.prism", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)
  46. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
  47. shield_handlers[value] = shield_handler
  48. print("Symbolic Description of the Model:")
  49. shield_handlers[1.0].print_symbolic_model()
  50. input("")
  51. if shield_needed(shielding):
  52. for value in [1.0]:
  53. create_shield_overlay_image(image_env, shield_handlers[value].create_shield())
  54. print(f"The shield for shield_value = {value}")
  55. input("")
  56. if shielding == ShieldingConfig.Training:
  57. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
  58. env = ActionMasker(env, mask_fn)
  59. print("Training with shield:")
  60. create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())
  61. elif shielding == ShieldingConfig.Disabled:
  62. env = ActionMasker(env, nomask_fn)
  63. else:
  64. assert(False)
  65. model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
  66. model.set_logger(logger)
  67. steps = 20_000
  68. model.learn(steps,callback=[InfoCallback()])
  69. if __name__ == '__main__':
  70. main()
  71. # In[ ]: