The source code and Dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

147 lines
4.5 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## Example usage of Tempestpy"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": null,
  13. "metadata": {
  14. "vscode": {
  15. "languageId": "plaintext"
  16. }
  17. },
  18. "outputs": [],
  19. "source": [
  20. "from sb3_contrib import MaskablePPO\n",
  21. "from sb3_contrib.common.wrappers import ActionMasker\n",
  22. "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
  23. "\n",
  24. "import gymnasium as gym\n",
  25. "\n",
  26. "from minigrid.core.actions import Actions\n",
  27. "from minigrid.core.constants import TILE_PIXELS\n",
  28. "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
  29. "\n",
  30. "import tempfile, datetime, shutil\n",
  31. "\n",
  32. "import time\n",
  33. "import os\n",
  34. "\n",
  35. "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
  36. "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
  37. "\n",
  38. "import os, sys\n",
  39. "from copy import deepcopy\n",
  40. "\n",
  41. "from PIL import Image"
  42. ]
  43. },
  44. {
  45. "cell_type": "code",
  46. "execution_count": null,
  47. "metadata": {
  48. "vscode": {
  49. "languageId": "plaintext"
  50. }
  51. },
  52. "outputs": [],
  53. "source": [
  54. "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n",
  55. "\n",
  56. "def mask_fn(env: gym.Env):\n",
  57. " return env.create_action_mask()\n",
  58. "\n",
  59. "def nomask_fn(env: gym.Env):\n",
  60. " return [1.0] * 7\n",
  61. "\n",
  62. "def main():\n",
  63. " env = \"MiniGrid-LavaGapS6-v0\"\n",
  64. "\n",
  65. " # TODO Change the safety specification\n",
  66. " formula = \"Pmax=? [G !AgentIsOnLava]\"\n",
  67. " value_for_training = 1.0\n",
  68. " shield_comparison = \"absolute\"\n",
  69. " shielding = ShieldingConfig.Training\n",
  70. " \n",
  71. " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
  72. " \n",
  73. "\n",
  74. " env = gym.make(env, render_mode=\"rgb_array\")\n",
  75. " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
  76. " env = RGBImgObsWrapper(env, 8)\n",
  77. " env = ImgObsWrapper(env)\n",
  78. " env = MiniWrapper(env)\n",
  79. "\n",
  80. " \n",
  81. " env.reset()\n",
  82. " Image.fromarray(env.render()).show()\n",
  83. " \n",
  84. " shield_handlers = dict()\n",
  85. " if shield_needed(shielding):\n",
  86. " for value in [0.0, 1.0]:\n",
  87. " shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n",
  88. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
  89. " create_shield_overlay_image(image_env, shield_handler.create_shield())\n",
  90. " print(f\"The shield for shield_value = {value}\")\n",
  91. "\n",
  92. " shield_handlers[value] = shield_handler\n",
  93. "\n",
  94. "\n",
  95. " if shielding == ShieldingConfig.Training:\n",
  96. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
  97. " env = ActionMasker(env, mask_fn)\n",
  98. " print(\"Training with shield:\")\n",
  99. " create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
  100. " elif shielding == ShieldingConfig.Disabled:\n",
  101. " env = ActionMasker(env, nomask_fn)\n",
  102. " else:\n",
  103. " assert(False) \n",
  104. " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
  105. " model.set_logger(logger)\n",
  106. " steps = 20_000\n",
  107. "\n",
  108. " \n",
  109. " model.learn(steps,callback=[InfoCallback()])\n",
  110. "\n",
  111. "\n",
  112. "\n",
  113. "if __name__ == '__main__':\n",
  114. " print(\"Starting the training\")\n",
  115. " main()"
  116. ]
  117. },
  118. {
  119. "cell_type": "code",
  120. "execution_count": null,
  121. "metadata": {},
  122. "outputs": [],
  123. "source": []
  124. }
  125. ],
  126. "metadata": {
  127. "kernelspec": {
  128. "display_name": "Python 3 (ipykernel)",
  129. "language": "python",
  130. "name": "python3"
  131. },
  132. "language_info": {
  133. "codemirror_mode": {
  134. "name": "ipython",
  135. "version": 3
  136. },
  137. "file_extension": ".py",
  138. "mimetype": "text/x-python",
  139. "name": "python",
  140. "nbconvert_exporter": "python",
  141. "pygments_lexer": "ipython3",
  142. "version": "3.10.12"
  143. }
  144. },
  145. "nbformat": 4,
  146. "nbformat_minor": 4
  147. }