You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

145 lines
4.5 KiB

2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
2 months ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## Example usage of Tempestpy"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": null,
  13. "metadata": {
  14. "vscode": {
  15. "languageId": "plaintext"
  16. }
  17. },
  18. "outputs": [],
  19. "source": [
  20. "from sb3_contrib import MaskablePPO\n",
  21. "from sb3_contrib.common.wrappers import ActionMasker\n",
  22. "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
  23. "\n",
  24. "import gymnasium as gym\n",
  25. "\n",
  26. "from minigrid.core.actions import Actions\n",
  27. "from minigrid.core.constants import TILE_PIXELS\n",
  28. "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
  29. "\n",
  30. "import tempfile, datetime, shutil\n",
  31. "\n",
  32. "import time\n",
  33. "import os\n",
  34. "\n",
  35. "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
  36. "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
  37. "\n",
  38. "import os, sys\n",
  39. "from copy import deepcopy\n",
  40. "\n",
  41. "from PIL import Image"
  42. ]
  43. },
  44. {
  45. "cell_type": "code",
  46. "execution_count": null,
  47. "metadata": {
  48. "vscode": {
  49. "languageId": "plaintext"
  50. }
  51. },
  52. "outputs": [],
  53. "source": [
  54. "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n",
  55. "\n",
  56. "def mask_fn(env: gym.Env):\n",
  57. " return env.create_action_mask()\n",
  58. "\n",
  59. "def nomask_fn(env: gym.Env):\n",
  60. " return [1.0] * 7\n",
  61. "\n",
  62. "def main():\n",
  63. " env = \"MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0\"\n",
  64. " \n",
  65. " formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
  66. " value_for_training = 0.99\n",
  67. " shield_comparison = \"absolute\"\n",
  68. " shielding = ShieldingConfig.Training\n",
  69. " \n",
  70. " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
  71. " \n",
  72. " env = gym.make(env, render_mode=\"rgb_array\")\n",
  73. " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
  74. " env = RGBImgObsWrapper(env, 8)\n",
  75. " env = ImgObsWrapper(env)\n",
  76. " env = MiniWrapper(env)\n",
  77. "\n",
  78. " \n",
  79. " env.reset()\n",
  80. " Image.fromarray(env.render()).show()\n",
  81. " \n",
  82. " shield_handlers = dict()\n",
  83. " if shield_needed(shielding):\n",
  84. " for value in [0.9, 0.95, 0.99, 1.0]:\n",
  85. " shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=False, prism_file=None)\n",
  86. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
  87. " create_shield_overlay_image(image_env, shield_handler.create_shield())\n",
  88. " print(f\"The shield for shield_value = {value}\")\n",
  89. "\n",
  90. " shield_handlers[value] = shield_handler\n",
  91. "\n",
  92. "\n",
  93. " if shielding == ShieldingConfig.Training:\n",
  94. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
  95. " env = ActionMasker(env, mask_fn)\n",
  96. " print(\"Training with shield:\")\n",
  97. " create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
  98. " elif shielding == ShieldingConfig.Disabled:\n",
  99. " env = ActionMasker(env, nomask_fn)\n",
  100. " else:\n",
  101. " assert(False) \n",
  102. " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
  103. " model.set_logger(logger)\n",
  104. " steps = 20_000\n",
  105. "\n",
  106. " #assert(False)\n",
  107. " model.learn(steps,callback=[InfoCallback()])\n",
  108. "\n",
  109. "\n",
  110. "\n",
  111. "if __name__ == '__main__':\n",
  112. " print(\"Starting the training\")\n",
  113. " main()"
  114. ]
  115. },
  116. {
  117. "cell_type": "code",
  118. "execution_count": null,
  119. "metadata": {},
  120. "outputs": [],
  121. "source": []
  122. }
  123. ],
  124. "metadata": {
  125. "kernelspec": {
  126. "display_name": "Python 3 (ipykernel)",
  127. "language": "python",
  128. "name": "python3"
  129. },
  130. "language_info": {
  131. "codemirror_mode": {
  132. "name": "ipython",
  133. "version": 3
  134. },
  135. "file_extension": ".py",
  136. "mimetype": "text/x-python",
  137. "name": "python",
  138. "nbconvert_exporter": "python",
  139. "pygments_lexer": "ipython3",
  140. "version": "3.10.12"
  141. }
  142. },
  143. "nbformat": 4,
  144. "nbformat_minor": 4
  145. }