You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

199 lines
28 KiB

2 days ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## Example usage of Tempestpy"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 1,
  13. "metadata": {
  14. "vscode": {
  15. "languageId": "plaintext"
  16. }
  17. },
  18. "outputs": [
  19. {
  20. "name": "stdout",
  21. "output_type": "stream",
  22. "text": [
  23. "pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)\n",
  24. "Hello from the pygame community. https://www.pygame.org/contribute.html\n"
  25. ]
  26. },
  27. {
  28. "name": "stderr",
  29. "output_type": "stream",
  30. "text": [
  31. "2024-11-29 12:18:59.459471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
  32. "2024-11-29 12:18:59.474489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
  33. "2024-11-29 12:18:59.478811: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
  34. "2024-11-29 12:18:59.488641: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
  35. "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
  36. "2024-11-29 12:19:00.368388: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
  37. "error: XDG_RUNTIME_DIR not set in the environment.\n"
  38. ]
  39. }
  40. ],
  41. "source": [
  42. "from sb3_contrib import MaskablePPO\n",
  43. "from sb3_contrib.common.wrappers import ActionMasker\n",
  44. "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
  45. "\n",
  46. "import gymnasium as gym\n",
  47. "\n",
  48. "from minigrid.core.actions import Actions\n",
  49. "from minigrid.core.constants import TILE_PIXELS\n",
  50. "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
  51. "\n",
  52. "import tempfile, datetime, shutil\n",
  53. "\n",
  54. "import time\n",
  55. "import os\n",
  56. "\n",
  57. "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
  58. "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
  59. "\n",
  60. "import os, sys\n",
  61. "from copy import deepcopy\n",
  62. "\n",
  63. "from PIL import Image"
  64. ]
  65. },
  66. {
  67. "cell_type": "code",
  68. "execution_count": null,
  69. "metadata": {
  70. "vscode": {
  71. "languageId": "plaintext"
  72. }
  73. },
  74. "outputs": [
  75. {
  76. "name": "stdout",
  77. "output_type": "stream",
  78. "text": [
  79. "Starting the training\n"
  80. ]
  81. },
  82. {
  83. "data": {
  84. "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAEgAWADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDDooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA8y1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KAPpaiiivzQ/IAooooA8A1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOv0s/Xy5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+
NU6KAPZKKKKACiiigAooooAKKKKAPJ9X/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAXNX/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQB9LUUUV+aH5AFFFFAHzrq//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6/Sz9fCiiigC5q//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6ACiiigC5q//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6ACiiigC5q//Iav/wDr4k/9CNU6uav/AMhq/wD+viT/ANCNU6ACiiigD2SiiigAooooAKKKKACiiigDzLVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6ALn9r6n/0Ebv8A7/t/jR/a+p/9BG7/AO/7f41TooA19U1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/wDQRu/+/wC3+NGr/wDIav8A/r4k/wDQjVOgC5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KAPZKKKKACiiigAooooAKKKKAPJ9X/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAXNX/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAXNX/AOQ1f/8AXxJ/6Eap1c1f/kNX/wD18Sf+hGqdABRRRQBc1f8A5DV//wBfEn/oRqnVzV/+Q1f/APXxJ/6Eap0AFFFFAFzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnQAUUUUAeyUUUUAFFFFABRRRQAUUUUAeZapqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/+gjd/9/2/xo1f/kNX/wD18Sf+hGqdAFz+19T/AOgjd/8Af9v
8aP7X1P8A6CN3/wB/2/xqnRQBr6pqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/+gjd/9/2/xo1f/kNX/wD18Sf+hGqdAFz+19
  85. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAEgCAIAAAAFWz53AAAR8ElEQVR4Ae2d4XnUSBKGxzz8vksAXwC7CeBN4DYBNoLZADAJkMCZAJiLgAQggTMJ7AWwJoEjgb1PFm6Klrqnx6WR5NKrxw+0uqtKU2/1fJZa1szFfr/fsUEAAhAYI/BsrJM+CEAAAh0BBIJ5AAEIFAkgEEU0DEAAAggEcwACECgSQCCKaBiAAAQQCOYABCBQJIBAFNEwAAEIIBDMAQhAoEgAgSiiYQACEEAgmAMQgECRAAJRRMMABCCAQDAHIACBIgEEooiGAQhAAIFgDkAAAkUCCEQRDQMQgAACwRyAAASKBBCIIhoGIAABBII5AAEIFAkgEEU0DEAAAggEcwACECgSQCCKaBiAAAQQCOYABCBQJIBAFNEwAAEIIBDMAQhAoEgAgSiiYQACEEAgmAMQgECRAAJRRMMABCCAQDAHIACBIgEEooiGAQhAAIFgDkAAAkUCCEQRDQMQgMBzEMQmcHh/OGuC+9/3Z41P8GUJzCEQh8MZ5+h+vyd+bQ69rw1OMgb/CsYZ5mfl6P4hLjH8DIkAgbAEEIiwpSUxCPgJIBB+hkSAQFgCCETY0pIYBPwEEAg/QyJAICwBBCJsaUkMAn4CCISfIREgEJYAAhG2tCQGAT8BBMLPkAgQCEsAgQhbWhKDgJ8AAuFnSAQIhCWAQIQtLYlBwE8AgfAzJAIEwhJAIMKWlsQg4CeAQPgZEgECYQkgEGFLS2IQ8BNAIPwMiQCBsAQQiLClJTEI+AkgEH6GRIBAWAIIRNjSkhgE/AQQCD9DIkAgLAEEImxpSQwCfgIIhJ8hESAQlgACEba0JAYBPwEEws+QCBAISwCBCFtaEoOAnwAC4WdIBAiEJYBAhC0tiUHATwCB8DMkAgTCEkAgwpaWxCDgJ3Chryf3RyECBCAQksDzGbI6HA7nO4oEjvgVvPCpwNFQAD71BJ2jXGI4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJMAAuEEiDsEIhNAICJXl9wg4CSAQDgB4g6ByAQQiMjVJTcIOAkgEE6AuEMgMgEEInJ1yQ0CTgIIhBMg7hCITACBiFxdcoOAkwAC4QSIOwQiE0AgIleX3CDgJIBAOAHiDoHIBBCIyNUlNwg4CSAQToC4QyAyAQQicnXJDQJOAgiEEyDuEIhMAIGIXF1yg4CTAALhBIg7BCITQCAiV5fcIOAkgEA4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJPAhb7+3Bkitvv7F4fYCW48u9+/MP9rU+B5bXCiscPhjO8xCdxZ479/OxEFwqyVwFnnz7nn57l/wXOJsdZpy+uCwAoIIBArKAIvAQJrJYBArLUyvC4IrIAAArGCIvASILBWAgjEWivD64LACgggECsoAi8BAmslgEA0Vebuf7vbu+OWsmk0U0C7Ed/SEMOZ+dij07YEEAhLo9h+82n3y793V4fi+19zWqOyaTRTQLsRv6eRMM7Mx9aCtiUwxx9K2eM90fbrq93d193nL937/+WL3c0/d1eX31LRnL7+1A1pe/G37t+jZoqggHYjvsU4Px9bC9qWwBx/av2k/1Ltr7ffcWWT+Ppqd3P7XRr0JlePNnW+u919+dq1NdetWSYunYXZiD8/H/2p9ZOen+f+S0oEwrxBx5pWIPpx+zZWj84akjTYAFYm1F+f+taR+JbGsD0tHwRiSNj2IBCWxkh7KBC9kaapThNeXn47axjxvO/qTjHuOgVJlyQly6yf+BmQbHcqPghEBjbbZQ0iA9K6qzd8y3u+u+j4cbmh8QDEr4M6N5/60bczyl2Mplpff9xd3nSLC5VNv9N0I6Nyp6P3VRCFUkC7ET/RWIRPOjqNjABnEBmQ4q4WHd986i4rhisOkoZ0
I0P+wzsdfVBN/bR4OTwM8ZflM6wIPSLAGsSRaZDWIKwKpIVJ29kvQypcEou0MGmnfurMDmxDEX94+XYmPqxBZPMw20UgMiD5bhKIfsBO02Q6fM83mqUIqdHo2GiWwqZGo2OjWQqbGo2OjWYpbGo0OjaaKSwCkdiONrjEGMVS7NRvttt99/eU/WnCUBp6z0az4WEaHRvNiF8v05APPRkBBCID0rTbvz+PmjaaDeM0OjaaEX9IgJ5GAtzFaALVnTJ8zJ8gGnpqrUE/9U2PISmUAtqN+InGInzS0WlkBDiDyICM7+ruw4f/7t593r1+2d3FuPx7bmaXIUfvdMhBU19DCqJN9yz0+z9txBeKBfmkQtDICLBImQHJd/tFSjt3ZWFlwkqDliS0pQe30g3Rint/vIoB8YXofHxYpOxnYOlfBKJE5lu/vYuRTdNXP+1uv3x/KCs94pmWMBVCNyyvXnRnH/1mleVbl/mP+PPzQSDMBBxpIhAjUGyXFYi+P3sbl25kWJmQY33q2yMS39IYtqflg0AMCdseBMLSGGkPBaI36qfpq59/WEoY+ksmPvwxvmwxNLY9xLc0hu2p+CAQQ7a2h0VKS+OEttYpb349bq+VSLsYedzhwYL4DyTG/z83n/Gjbq+X25zbqzkZQ6CZAALRjApDCGyPAAKxvZqTMQSaCSAQzagwhMD2CMxxF+NJU33/4vCkXz8vvk5AdzHqBhsfneMuxpP+1OD3bzc+Q+Kn/6Tn57k/1ZpLjPhvADKEwKMJIBCPRocjBOIT2JpA6Gmq+weq4leWDCEwAYGtCcRrPRUxATZCQGAbBLYmEK92O/2wQQACTQQ2JRA6d7i8/+EkomlyYASBTQlEOndIjdYJoGcH9Vzm0U02jWYKaDfiWxpiODMfe3TalsB2BOKl+Q48fRmedk/Y9JU5+jqcyrdmaU5rVDaNZgpoN+L3NBLGmfnYWtC2BOb4Qyl7vOXa2VmDdu8/HLLtBb2+2t197T5LbvitWZrT6Zty9PlR2o6a6TNmFNBuxLcY5+dja0HbEpjjT61X8JdqurU5vELQesQXy2K0bT8wJpvE+mJefWDk8BMos0+RtGalj5/qD038+fnwgTGj0z51bkQgtCp5k3J+aFzvdu8e2sX/rUD0RvZtrJ70NXlZCCsTGqpPfetLfEtj2J6WDwIxJGx7NiIQf97fvLCJq61zin9kXcPdoUD0Npqm+qz6l5c7nSBUtu4U4667oDj1c6WIX6Gqoan4IBB1zltYg9B6pK4mhps6NXTCSoQN0fhZcp18VBXExrRt4lsaw/a5+QyPuM2eLdzF0KVEaasM/eCi78K6vDnyrVn6naYbGZU7HX1EnVMolALajfiJxiJ80tFpZATCn0FoeTK7f2EJaEgGx5cq5aPvwnrzqbus0PVCdlkhaUg3MmQ5vNPRHzJblbCvg/gisCyfrBzs9gTCr0H8a7ernyZo8fJNZTakNQirAmlh0nb2y5AKlcQiLUzaqZ86s4PaUMQfLtmciQ9rENk8zHbDC8SfhQWIxOHIUmUSiN7BTtMUYviebzRLEVKj0bHRLIVNjUbHRrMUNjUaHRvNUtjUaHRsNFNYBCKxHW3EvsTQFYRWIuubDGT2oW6URvWb7XbfLaH3pwlDaegtG81S2NRodGw0S2FTo9Gx0SyFTY1Gx0azFDY1Gh0bzVJYGiUCsQWi8aEsmbUKRM+xn38lpqm/0SzZp0ajY6NZCpsajY6NZilsajQ6NpqlsKnR6NholsLSGBJ4NuyK0qNbmI03GGUm49rWnTJ8zJ8gGjporUE/9U2PISmUAtqN+InGInzS0WlkBAKfQVRuXmQQtCvj2h9E6OaFvqH73efid/DaZcjROx06hqa+hhREm+6J6Pdb2ogvFAvySYWgkREIvEj5V5bqsd2LUYN+kdLOXZnZr+q20qAlCW3DpzMq7v1BKwbEF6Lz8WGRsp+BpX+jCsTowxclCH2/7oaOPJrRC0RvkU3TVz/tbr90
5wLa7GplWsJUv25YXr3ozj76zSrLty7zH/Hn54NAmAk40owqEP9pXoBIULR48EvaSQ0rEH1n9ja20pC81LAyod361LeOxLc
  86. "text/plain": [
  87. "<PIL.Image.Image image mode=RGB size=352x288>"
  88. ]
  89. },
  90. "metadata": {},
  91. "output_type": "display_data"
  92. },
  93. {
  94. "name": "stdout",
  95. "output_type": "stream",
  96. "text": [
  97. "\n",
  98. "\n",
  99. "Computing new shield\n",
  100. "LOG: Starting with explicit model creation...\n"
  101. ]
  102. }
  103. ],
  104. "source": [
  # Value of the M2P_BINARY environment variable (None when it is unset).
  # It is handed to MiniGridShieldHandler as its first argument -- presumably
  # the path of the grid-to-PRISM converter binary; TODO confirm.
  GRID_TO_PRISM_BINARY = os.environ.get("M2P_BINARY")
  106. "\n",
  def mask_fn(env: gym.Env):
      """Action-mask callback: return whatever mask ``env`` currently reports.

      ``env`` must expose ``create_action_mask()``; in this notebook that is the
      shielding wrapper placed directly underneath ``ActionMasker``.
      """
      action_mask = env.create_action_mask()
      return action_mask
  109. "\n",
  def nomask_fn(env: gym.Env):
      """Action-mask callback that allows every action (used when shielding is off).

      The mask length follows ``env.action_space.n`` when the environment
      exposes a discrete action space, so the mask always matches the number of
      actions the policy chooses from. If that attribute is missing, the
      previous hard-coded length of 7 is used.

      Args:
          env: The environment being masked.

      Returns:
          A list of ``1.0`` entries, one per action.
      """
      action_space = getattr(env, "action_space", None)
      # int(...) because spaces may report ``n`` as a NumPy integer.
      num_actions = int(getattr(action_space, "n", 7))
      return [1.0] * num_actions
  112. "\n",
  def main(env_id="MiniGrid-WindyCity-Adv-v0", value_for_training=0.99, steps=200):
      """Build a (shielded) MiniGrid environment and train a MaskablePPO agent on it.

      Args:
          env_id: Gymnasium id of the environment to create. Another id that was
              tried here is "MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0".
          value_for_training: Shield value whose handler masks the actions during
              training; must be one of the values a shield handler is built for.
          steps: Number of timesteps handed to ``model.learn``.

      Raises:
          ValueError: If no shield handler exists for ``value_for_training`` or
              the shielding configuration is neither ``Training`` nor
              ``Disabled``.
      """
      formula = "Pmax=? [G ! AgentIsOnLava]"
      shield_comparison = "absolute"
      shielding = ShieldingConfig.Training
      shield_values = (0.9, 0.95, 0.99, 1.0)

      # Only a stdout formatter is registered; the folder is the system temp dir
      # instead of a hard-coded "/tmp".
      logger = Logger(tempfile.gettempdir(), output_formats=[HumanOutputFormat(sys.stdout)])

      base_env = gym.make(env_id, render_mode="rgb_array")
      # Full-size rendering of the same base env; only used for shield overlays.
      image_env = RGBImgObsWrapper(base_env, TILE_PIXELS)
      # Observation pipeline the agent is trained on.
      env = RGBImgObsWrapper(base_env, 8)
      env = ImgObsWrapper(env)
      env = MiniWrapper(env)

      env.reset()
      Image.fromarray(env.render()).show()

      shield_handlers = {}
      if shield_needed(shielding):
          for value in shield_values:
              shield_handler = MiniGridShieldHandler(
                  GRID_TO_PRISM_BINARY,
                  "grid.txt",
                  "grid.prism",
                  formula,
                  shield_value=value,
                  shield_comparison=shield_comparison,
                  nocleanup=True,
                  prism_file=None,
              )
              # NOTE(review): each iteration stacks one more shielding wrapper on
              # top of env, and the training branch below adds another. This
              # looks unintended, but the wrapper's constructor may be what
              # initialises the handler -- confirm against sb3utils before
              # removing the wrapping.
              env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
              shield_handlers[value] = shield_handler

          for value in shield_values:
              create_shield_overlay_image(image_env, shield_handlers[value].create_shield())
              print(f"The shield for shield_value = {value}")

      if shielding == ShieldingConfig.Training:
          if value_for_training not in shield_handlers:
              raise ValueError(
                  f"No shield handler for value_for_training={value_for_training}; "
                  f"available values: {sorted(shield_handlers)}"
              )
          training_handler = shield_handlers[value_for_training]
          env = MiniGridSbShieldingWrapper(env, shield_handler=training_handler, create_shield_at_reset=False)
          env = ActionMasker(env, mask_fn)
          print("Training with shield:")
          create_shield_overlay_image(image_env, training_handler.create_shield())
      elif shielding == ShieldingConfig.Disabled:
          env = ActionMasker(env, nomask_fn)
      else:
          # Explicit raise instead of assert(False): asserts vanish under -O and
          # training would then silently continue without any ActionMasker.
          raise ValueError(f"Unsupported shielding configuration: {shielding}")

      model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
      model.set_logger(logger)

      # The unconditional assert(False) that used to precede this call made the
      # training step unreachable.
      model.learn(steps, callback=[InfoCallback()])
  162. "\n",
  163. "\n",
  164. "\n",
  # Entry point of the cell: announce the run, then start main().
  if __name__ == "__main__":
      print("Starting the training")
      main()
  168. ]
  169. },
  170. {
  171. "cell_type": "code",
  172. "execution_count": null,
  173. "metadata": {},
  174. "outputs": [],
  175. "source": []
  176. }
  177. ],
  178. "metadata": {
  179. "kernelspec": {
  180. "display_name": "Python 3 (ipykernel)",
  181. "language": "python",
  182. "name": "python3"
  183. },
  184. "language_info": {
  185. "codemirror_mode": {
  186. "name": "ipython",
  187. "version": 3
  188. },
  189. "file_extension": ".py",
  190. "mimetype": "text/x-python",
  191. "name": "python",
  192. "nbconvert_exporter": "python",
  193. "pygments_lexer": "ipython3",
  194. "version": "3.10.12"
  195. }
  196. },
  197. "nbformat": 4,
  198. "nbformat_minor": 4
  199. }