The source code and Dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

143 lines
4.4 KiB

2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## Example usage of Tempestpy"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": null,
  13. "metadata": {
  14. "vscode": {
  15. "languageId": "plaintext"
  16. }
  17. },
  18. "outputs": [],
  19. "source": [
  20. "from sb3_contrib import MaskablePPO\n",
  21. "from sb3_contrib.common.wrappers import ActionMasker\n",
  22. "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
  23. "\n",
  24. "import gymnasium as gym\n",
  25. "\n",
  26. "from minigrid.core.actions import Actions\n",
  27. "from minigrid.core.constants import TILE_PIXELS\n",
  28. "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
  29. "\n",
  30. "import tempfile, datetime, shutil\n",
  31. "\n",
  32. "import time\n",
  33. "import os\n",
  34. "\n",
  35. "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
  36. "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
  37. "\n",
  38. "import os, sys\n",
  39. "from copy import deepcopy\n",
  40. "\n",
  41. "from PIL import Image"
  42. ]
  43. },
  44. {
  45. "cell_type": "code",
  46. "execution_count": null,
  47. "metadata": {
  48. "vscode": {
  49. "languageId": "plaintext"
  50. }
  51. },
  52. "outputs": [],
  53. "source": [
  54. "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n",
  55. "\n",
  56. "def mask_fn(env: gym.Env):\n",
  57. " return env.create_action_mask()\n",
  58. "\n",
  59. "def nomask_fn(env: gym.Env):\n",
  60. " return [1.0] * 7\n",
  61. "\n",
  62. "def main():\n",
  63. " env = \"MiniGrid-LavaFaultyS15-1-v0\"\n",
  64. " \n",
  65. " formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
  66. " value_for_training = 0.0\n",
  67. " shield_comparison = \"absolute\"\n",
  68. " shielding = ShieldingConfig.Training\n",
  69. " \n",
  70. " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
  71. " \n",
  72. " env = gym.make(env, render_mode=\"rgb_array\")\n",
  73. " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
  74. " env = RGBImgObsWrapper(env, 8)\n",
  75. " env = ImgObsWrapper(env)\n",
  76. " env = MiniWrapper(env)\n",
  77. "\n",
  78. " \n",
  79. " env.reset()\n",
  80. " Image.fromarray(env.render()).show()\n",
  81. " \n",
  82. " shield_handlers = dict()\n",
  83. " if shield_needed(shielding):\n",
  84. " for value in [0.0, 1.0]: \n",
  85. " shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=False, prism_file=None)\n",
  86. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
  87. " create_shield_overlay_image(image_env, shield_handler.create_shield())\n",
  88. " shield_handlers[value] = shield_handler\n",
  89. "\n",
  90. "\n",
  91. " if shielding == ShieldingConfig.Training:\n",
  92. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
  93. " env = ActionMasker(env, mask_fn)\n",
  94. " print(\"Training with shield:\")\n",
  95. " create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
  96. " elif shielding == ShieldingConfig.Disabled:\n",
  97. " env = ActionMasker(env, nomask_fn)\n",
  98. " else:\n",
  99. " assert(False) \n",
  100. " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
  101. " model.set_logger(logger)\n",
  102. " steps = 20_000\n",
  103. "\n",
  104. " assert(False)\n",
  105. " model.learn(steps,callback=[InfoCallback()])\n",
  106. "\n",
  107. "\n",
  108. "\n",
  109. "if __name__ == '__main__':\n",
  110. " print(\"Starting the training\")\n",
  111. " main()"
  112. ]
  113. },
  114. {
  115. "cell_type": "code",
  116. "execution_count": null,
  117. "metadata": {},
  118. "outputs": [],
  119. "source": []
  120. }
  121. ],
  122. "metadata": {
  123. "kernelspec": {
  124. "display_name": "Python 3 (ipykernel)",
  125. "language": "python",
  126. "name": "python3"
  127. },
  128. "language_info": {
  129. "codemirror_mode": {
  130. "name": "ipython",
  131. "version": 3
  132. },
  133. "file_extension": ".py",
  134. "mimetype": "text/x-python",
  135. "name": "python",
  136. "nbconvert_exporter": "python",
  137. "pygments_lexer": "ipython3",
  138. "version": "3.10.12"
  139. }
  140. },
  141. "nbformat": 4,
  142. "nbformat_minor": 4
  143. }