You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

105 lines
3.0 KiB

  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. # ## Example usage of Tempestpy
  4. # In[1]:
  5. from sb3_contrib import MaskablePPO
  6. from sb3_contrib.common.wrappers import ActionMasker
  7. from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat
  8. import gymnasium as gym
  9. from minigrid.core.actions import Actions
  10. from minigrid.core.constants import TILE_PIXELS
  11. from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
  12. import tempfile, datetime, shutil
  13. import time
  14. import os
  15. from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image
  16. from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback
  17. import os, sys
  18. from copy import deepcopy
  19. from PIL import Image
  20. # In[3]:
  21. GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY")
  22. def mask_fn(env: gym.Env):
  23. return env.create_action_mask()
  24. def nomask_fn(env: gym.Env):
  25. return [1.0] * 7
  26. def main():
  27. env = "MiniGrid-LavaFaultyS15-1-v0"
  28. formula = "Pmax=? [G ! AgentIsOnLava]"
  29. value_for_training = 0.0
  30. shield_comparison = "absolute"
  31. shielding = ShieldingConfig.Training
  32. logger = Logger("/tmp", output_formats=[HumanOutputFormat(sys.stdout)])
  33. env = gym.make(env, render_mode="rgb_array")
  34. image_env = RGBImgObsWrapper(env, TILE_PIXELS)
  35. env = RGBImgObsWrapper(env, 8)
  36. env = ImgObsWrapper(env)
  37. env = MiniWrapper(env)
  38. env.reset()
  39. Image.fromarray(env.render()).show()
  40. shield_values = [0.0, 0.9, 0.99, 0.999, 1.0]
  41. shield_handlers = dict()
  42. if shield_needed(shielding):
  43. for value in shield_values:
  44. shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, "grid.txt", "grid.prism", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=False, prism_file=None)
  45. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
  46. shield_handlers[value] = shield_handler
  47. if shield_needed(shielding):
  48. for value in shield_values:
  49. create_shield_overlay_image(image_env, shield_handlers[value].create_shield())
  50. print(f"The shield for shield_value = {value}")
  51. if shielding == ShieldingConfig.Training:
  52. env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)
  53. env = ActionMasker(env, mask_fn)
  54. print("Training with shield:")
  55. create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())
  56. elif shielding == ShieldingConfig.Disabled:
  57. env = ActionMasker(env, nomask_fn)
  58. else:
  59. assert(False)
  60. model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
  61. model.set_logger(logger)
  62. steps = 20_000
  63. assert(False)
  64. model.learn(steps,callback=[InfoCallback()])
  65. if __name__ == '__main__':
  66. print("Starting the training")
  67. main()
  68. # In[ ]: