The source code and Dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

250 lines
18 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## Example usage of Tempestpy"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 1,
  13. "metadata": {
  14. "vscode": {
  15. "languageId": "plaintext"
  16. }
  17. },
  18. "outputs": [
  19. {
  20. "name": "stdout",
  21. "output_type": "stream",
  22. "text": [
  23. "pygame 2.6.0 (SDL 2.28.4, Python 3.10.12)\n",
  24. "Hello from the pygame community. https://www.pygame.org/contribute.html\n"
  25. ]
  26. },
  27. {
  28. "name": "stderr",
  29. "output_type": "stream",
  30. "text": [
  31. "2024-09-24 13:10:36.778354: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
  32. "2024-09-24 13:10:36.792058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
  33. "2024-09-24 13:10:36.796147: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
  34. "2024-09-24 13:10:36.806045: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
  35. "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
  36. "2024-09-24 13:10:37.640233: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
  37. "error: XDG_RUNTIME_DIR not set in the environment.\n"
  38. ]
  39. }
  40. ],
  41. "source": [
  42. "from sb3_contrib import MaskablePPO\n",
  43. "from sb3_contrib.common.wrappers import ActionMasker\n",
  44. "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
  45. "\n",
  46. "import gymnasium as gym\n",
  47. "\n",
  48. "from minigrid.core.actions import Actions\n",
  49. "from minigrid.core.constants import TILE_PIXELS\n",
  50. "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
  51. "\n",
  52. "import tempfile, datetime, shutil\n",
  53. "\n",
  54. "import time\n",
  55. "import os\n",
  56. "\n",
  57. "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
  58. "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
  59. "\n",
  60. "import os, sys\n",
  61. "from copy import deepcopy\n",
  62. "\n",
  63. "from PIL import Image"
  64. ]
  65. },
  66. {
  67. "cell_type": "code",
  68. "execution_count": null,
  69. "metadata": {
  70. "vscode": {
  71. "languageId": "plaintext"
  72. }
  73. },
  74. "outputs": [
  75. {
  76. "data": {
  77. "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCADAAMADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDDooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAPN9Qu9XbU9Q+z3V4Yop3BEcrYUbjjgHpxWf/a+p/8AQRu/+/7f4102kf8AIa1v/r4P/oT1cvNIsr7JlhAc/wAacH/6/wCNcVTGKnUcJLQ+nwnDc8ZgoYmjP3nfR7aNrf5djjf7X1P/AKCN3/3/AG/xo/tfU/8AoI3f/f8Ab/GtO88L3EWWtZBMv90/K3+BrEmhlt5Ck0bIw7MMV0060Knws8XF5disG7V4NefT79j6TorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfAxw8ZJPnS++/wB3X5fgfWyryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8AIq5p1hcQapHPq2tGe/dG8u1jYRxKvfanVsepolh4xTfOn99/u6fP8QjXk2lytfdb7+vy/A8e1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/8AQRu/+/7f40av/wAhq/8A+viT/wBCNU6/Qj4Uuf2vqf8A0Ebv/v8At/jR/a+p/wDQRu/+/wC3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/0Ebv8A7/t/jRq//Iav/wDr4k/9CNU6ALn9r6n/ANBG7/7/ALf40f2vqf8A0Ebv/v8At/jVOigD2SiiigAooooAKKKKACiiigDjdI/5DWt/9fB/9Cerl5q9lY5EswLj+BOT/wDW/GuS1K5mh1fUkilZFe4fcFOM4Y/41nVxVMGqlRzk9D6fCcSTweChhqMPeV9Xtq29vn3N+88UXEuVtYxCv94/M3+ArEmmluJC80jOx7sc1HRXTTowp/CjxcXmOKxjvXm35dPu2Pa9ej8O3etJbXs0lhqgUNDeITET7B+h+h/Cq1/p+m6ckc/irXJdR2/6m3cbVb0/dr94+549a628sbXUbdre8t454j1WRcis7S/C2j6PKZrS0HndpJCXZR2AJ6DtxXxFPFRjBJylp00
/B7x9NT6ephpSm2ktfX8VtL8DDsNP03UUkn8K65Lp27/XW6Dcq+v7tvun3HHpVnQY/DtprT21lNJf6oVLTXjkykexfoPoPxrT1Twto+sSia7tB53eSMlGYdwSOo7c1o2dja6dbrb2dvHBEOixrgUqmKjKDSlLX0/F7y/AKeGlGabS09fwW0fxPnzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnX3h8WFFFFAFzV/+Q1f/APXxJ/6Eap1c1f8A5DV//wBfEn/oRqnQAUUUUAeyUUUUAFFFFABRRRQAUUUUAeZapqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/wDoI3f/AH/b/GjV/wDkNX//AF8Sf+hGqdAFz+19T/6CN3/3/b/Gj+19T/6CN3/3/b/GqdFAGvqmqahHq96iX90qLO4VVmYADceBzVT+19T/AOgjd/8Af9v8aNX/AOQ1f/8AXxJ/6Eap0AXP7X1P/oI3f/f9v8aP7X1P/oI3f/f9v8ap0UAa+qapqEer3qJf3Sos7hVWZgANx4HNVP7X1P8A6CN3/wB/2/xo1f8A5DV//wBfEn/oRqnQBc/tfU/+gjd/9/2/xo/tfU/+gjd/9/2/xqnRQBr6pqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/wDoI3f/AH/b/GjV/wDkNX//AF8Sf+hGqdAFz+19T/6CN3/3/b/Gj+19T/6CN3/3/b/GqdFAHslFFFABRRRQAUUUUAFFFFAHk+r/APIav/8Ar4k/9CNU6uav/wAhq/8A+viT/wBCNU6ACiiigC5q/wDyGr//AK+JP/QjVOrmr/8AIav/APr4k/8AQjVOgAooooAuav8A8hq//wCviT/0I1Tq5q//ACGr/wD6+JP/AEI1ToAKKKKALmr/APIav/8Ar4k/9CNU6uav/wAhq/8A+viT/wBCNU6ACiiigD2SiiigAooooAKKKKACiiigDzLVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/wBBG7/7/t/jRq//ACGr/wD6+JP/AEI1ToAuf2vqf/QRu/8Av+3+NH9r6n/0Ebv/AL/t/jVOigD6XorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfnscPGST50vvv93X5fgfdSryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8irmnWFxBqkc+ra0Z790by7WNhHEq99qdWx6miWHjFN86f33+7p8/xCNeTaXK191vv6/L8Dx7VNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6/Qj4Uuf2vqf/QRu/8Av+3+NH9r6n/0Ebv/AL/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/8AQRu/+/7f40av/wAhq/8A+viT/wBCNU6ALn9r6n/0Ebv/AL/t/jR/a+p/9BG7/wC/7f41TooA9kooooAKKKKACiiigAooooA8n1f/AJDV/wD9fEn/AKEap1c1f/kNX/8A18Sf+hGqdABRRRQB7Xr0fh271pLa9mksNUChobxCYifYP0P0P4VWv9P03Tkjn8Va5LqO3/U27jaren7tfvH3PHrXW3lja6jbtb3lvHPEeqyLkVnaX4W0fR5TNaWg87tJIS7KOwBPQduK+Ep4qMYJOUtOmn4PePpqfaVMNKU20lr6/itpfgYdhp+m6ikk/hXXJdO3f663QblX1/dt90+449Ks6DH4dtNae2sppL/VCpaa8cmUj2L9B9B+NaeqeFtH1iUTXdoPO7yRkozDuCR1HbmtGzsbXTrdbezt44Ih0WNcClUxUZQaUpa+n4veX4BTw0ozTaWnr+C2j+J8+av/AMhq/wD+viT/ANCNU6uav/yGr/8A6+JP/QjVOvvD4sKKKKALmr/8hq//AOviT/0I1Tq5q/8AyGr/AP6
+JP8A0I1ToAKKKKAPZKKKKACiiigAooooAKKKKAPMtU1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyG
  78. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMAAAADACAIAAADdvvtQAAAKN0lEQVR4Ae2d8VUctxaH5Zz8/dIApICkATYNOA3gBoILMG7ADRgKgDQQGrAbCDSQFBBo4LmB937jS5TJzI4kVtxhV/p05pi1dHVX+uk7d3SXRfPq7OwsUFBgVwW+2bUj/VBgUACA4KBKAQCqko/OAAQDVQoAUJV8dAYgGKhSAICq5KMzAMFAlQIAVCUfnQEIBqoUAKAq+egMQDBQpQAAVclHZwCCgSoFAKhKPjoDEAxUKQBAVfLRGYBgoEoBAKqSj84ABANVCgBQlXx0BiAYqFIAgKrkozMAwUCVAgBUJR+dAQgGqhQAoCr56AxAMFClAABVyUdnAIKBKgUAqEo+On976BJcHV27TuHtA+cnpQReA6Dra8c1vvqQmt6ztLmOXwd8eft/FhGWnHALW1KG+iIFAKhIJoyWFACgJWWoL1IAgIpkwmhJgVUB+i2Ek6WBUH+YCqwK0GkItyH8HsK7wxSLUc8VWBUge/tNCBch/C+Ej88XkO7/G27v57Ob1sim0EwOKSUKvABAcVjnfwckRabK8v5z+OnXsLle5EPcqFU2hWZySClRYI0PEtPjUEDSpfBxE8JlCA9p64XWd5tw/yXcPQx8nByFi9dhc/xoKnTOPw9NKkf/Gf7NmsmDHFJKFHh5gGyUWm4FJF3CSDe4u5Kxj2yEy+3ZEH6MFcPofBMubv9BR0yoRkWVl1/r52YT+EbvwMvtCuwLQHF0up3pUkBSNBJMTwpIE4zeqP/XqBPRsXcRRsaWYWRmoBOX4Ekv9g4gG70CkuKQXcLgSQEpYiQ+To4fo85clIjR3f1ww4q3vLklNQkF9hSgOGK7ryn5F0aKSeVFQJQwMdzU2O6Uyzqz/GZWs48VWmJFo7++Zv5H2wZ4/ikcXwybm0SxRCyRqVlfOZErOaSUKLDvEahkDmbz8CW8/zzsjic7HrXGzbVZzjM1q7fNtfxQyhU4DICyt7CLn8Ppj48p2BijMTq2TZY040zNEv4xOuymy+mR5b4DpDtX4SY67p2ND2Gky8qEiUnCH/WamMV6XiQU2FOAdkvjNc8JRktMFJolhKPJFNg7gBRvFHWelLfP19L4mNdPagrNJr3471iBfcnCFHLEjT7+ebMTPcNe51PI/gZUex1d6SInciWHlBIFXj4CaUHtQ+eS4S7ZKPm6+TNc3oV3J0MWdvzd1HC8Td6aqamD0FGTnKgoF1N8omQVeEmAyjfI2Wl8fD38olRrb9cYozE62hKp6Jep40xNNWN09F/rPphScgq8AEDZnDw35i3tCjnK5BV7LIQYRqc/hNuHIZaojHfTMbc3jDZHQ/SyMibvsYofSQVWBehZNsiJ6UwwMizG6FjfSQpmZqCTEDbRtCpA2iCvUMYY6dPFpa1MxOjmj+3bphWG2sBbrArQmnoZRtl3FEZLhGX7YiAF9iWNZzEOVAEAOtCF25dhA9C+rMSBjgOADnTh9mXYr3S8yL6MhXEcoAJrZGHe59/gPwGed4DgFpYQn6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCB38+0NXRdWJ69U1vHzg/KaXiwZ8PdPUhNb1naTv084eeRYQlJ9zClpShvkgBACqSCaMlBQBoSRnqixQAoCKZMFpSAICWlKG+SAEAKpIJoyUFGgFIT4wreUilbArN5JBSokAjAOnRcT/9GjbXi3yIG7XKptBMDikl
CqzxQWLJOCpt9KzC+y/DsyzFx+QJc0LHniSvt9BjMVWyZvIgh5QSBRoBKD49zlgxjM43wzOahYuK0BETqlGxp6hGjMZmE/gGa0pSgUYAsjlOMHqjR2z+Gx0zEzEGjR6wKozMDHRMnKf+2xRANvmIkfg4OX6MOnNdIkZ390Nw4rmFc4lKahoEKGJUwoQwCmx3SkhZsGkkCzv/FI4vhs1NolgilsjUrK+cyJUcUkoUaCcC6fnw7z8Pz42Pm+U4/3Eipsp5pmaWtrm258zHvrxIK9BIBLr4Ofz+y5DAG0YxGsWPf7RZVqtszMxSsBiNLOqIP3U3MzmklCjQTgSKe2fL5EWDLiuTDOv2bPi8MSb8UaaJWaznRUKBdgCySU4wWmKi0CwhHE2mQGsAjTHKrrFhlDXDIKFAI3ug4Zb0KWR/A6q9jq50kRO5kkNKiQKNRCAlXzd/hsu78O5kyMKOv5vOfZxhbc3U1EHoqElOVLSbVnyiZBVoJAJ9fD2go6Ll//7yX9FokmFpVzTJ1NTLoo46Gj1yJYeUEgUaiUAKOUq8FXsshIgDXac/hNuHAReV8W46pmBK02S/ORqil5WlAPbYzI+ZAo0AZPOaYGRYjNExs0kKZmagM2OjqKIpgOYYnf64uJWJGN38sX3bVKRf90YNAjTGKLu+wojNclalhEEjm+jEDGlyVQCAXOVt3zkAtb/GrjMEIFd523e+xib67OzMU0jf84E0ctfxX19dhytPed56Og9hDYBcz9e5+uArkLy7jt+XHndtArcwf42bfgcAanp5/ScHQP4aN/0OANT08vpPDoD8NW76HQCo6eX1n1wjAOkbYSVfQpVNoZkcUkoUaAQgfTVMfy4Y/85rPnNxw/lAc1nqa9b4ILF+lFkP+i4i5wNlVfIwaASg+O2w+OeC+iKineGiP0JV4XwgD3rksxGATJ0JRpwP5ATN2G1TAE0w0hfmOR9ovNgerxsEKGJU8l1VzgeqpKqRLIzzgSo52Ll7OxHI/lxQty1lZENcGRXl8La5tjrOBxppU/uykQjE+UC1IOzav50INEnB9NGiLiuTvy3kfKBdadnSrx2AbHITjCboRAEKzaI9L5YUaA0gm6fxsTTnWF9oFu15MVegkT3QsE3mfKD58vrXNBKBlHxxPpA/LVveoZEI9JHzgbYs7hpVjUSgycEunA+0Bjtf36MRgEyvCUacD7QCRk0BNMeI84G8GWoQoDFGWfmUyeui7KxAI5vonedPx0oFAKhSwN67A1DvBFTOH4AqBey9+yvXw296V7eD+ROBOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9p/h/BVTY14Za+YMAAAAASUVORK5CYII=",
  79. "text/plain": [
  80. "<PIL.Image.Image image mode=RGB size=192x192>"
  81. ]
  82. },
  83. "metadata": {},
  84. "output_type": "display_data"
  85. },
  86. {
  87. "name": "stdin",
  88. "output_type": "stream",
  89. "text": [
  90. " \n"
  91. ]
  92. },
  93. {
  94. "name": "stdout",
  95. "output_type": "stream",
  96. "text": [
  97. "\n",
  98. "\n",
  99. "Computing new shield\n",
  100. "LOG: Starting with explicit model creation...\n",
  101. "Elapsed time is 0.003572702407836914 seconds.\n",
  102. "LOG: Starting with model checking...\n",
  103. "Elapsed time is 0.0002765655517578125 seconds.\n",
  104. "LOG: Starting to translate shield...\n",
  105. "Write to file shielding_files_20240924T131106_4c2n0rfg/shield.\n",
  106. "Elapsed time is 0.0036995410919189453 seconds.\n",
  107. "\n",
  108. "\n",
  109. "Computing new shield\n",
  110. "LOG: Starting with explicit model creation...\n",
  111. "Elapsed time is 0.001905202865600586 seconds.\n",
  112. "LOG: Starting with model checking...\n",
  113. "Elapsed time is 0.0001571178436279297 seconds.\n",
  114. "LOG: Starting to translate shield...\n",
  115. "Elapsed time is 0.0028374195098876953 seconds.\n",
  116. "Symbolic Description of the Model:\n",
  117. "Write to file shielding_files_20240924T131106_mkizoz41/shield.\n",
  118. "mdp\n",
  119. "\n",
  120. "formula AgentCannotMoveEastWall = (colAgent=4&rowAgent=1) | (colAgent=4&rowAgent=2) | (colAgent=4&rowAgent=3) | (colAgent=4&rowAgent=4);\n",
  121. "formula AgentCannotMoveNorthWall = (colAgent=3&rowAgent=1) | (colAgent=4&rowAgent=1) | (colAgent=1&rowAgent=1) | (colAgent=2&rowAgent=1);\n",
  122. "formula AgentCannotMoveSouthWall = (colAgent=1&rowAgent=4) | (colAgent=3&rowAgent=4) | (colAgent=4&rowAgent=4) | (colAgent=2&rowAgent=4);\n",
  123. "formula AgentCannotMoveWestWall = (colAgent=1&rowAgent=2) | (colAgent=1&rowAgent=3) | (colAgent=1&rowAgent=4) | (colAgent=1&rowAgent=1);\n",
  124. "formula AgentIsOnSlippery = false;\n",
  125. "formula AgentIsOnLava = (colAgent=2&rowAgent=1) | (colAgent=2&rowAgent=3) | (colAgent=2&rowAgent=4);\n",
  126. "formula AgentIsOnGoal = (colAgent=4&rowAgent=4);\n",
  127. "init\n",
  128. " true\n",
  129. "endinit\n",
  130. "\n",
  131. "\n",
  132. "module Agent\n",
  133. " colAgent : [1..4];\n",
  134. " rowAgent : [1..4];\n",
  135. " viewAgent : [0..3];\n",
  136. "\n",
  137. " [Agent_turn_right] !AgentIsOnLava &true -> 1.000000: (viewAgent'=mod(viewAgent+1,4));\n",
  138. " [Agent_turn_left] !AgentIsOnLava &viewAgent>0 -> 1.000000: (viewAgent'=viewAgent-1);\n",
  139. " [Agent_turn_left] !AgentIsOnLava &viewAgent=0 -> 1.000000: (viewAgent'=3);\n",
  140. " [Agent_move_North] viewAgent=3 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveNorthWall -> 1.000000: (rowAgent'=rowAgent-1);\n",
  141. " [Agent_move_East] viewAgent=0 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveEastWall -> 1.000000: (colAgent'=colAgent+1);\n",
  142. " [Agent_move_South] viewAgent=1 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveSouthWall -> 1.000000: (rowAgent'=rowAgent+1);\n",
  143. " [Agent_move_West] viewAgent=2 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveWestWall -> 1.000000: (colAgent'=colAgent-1);\n",
  144. "endmodule\n",
  145. "\n",
  146. "\n"
  147. ]
  148. }
  149. ],
  150. "source": [
  151. "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n",
  152. "\n",
  153. "def mask_fn(env: gym.Env):\n",
  154. " return env.create_action_mask()\n",
  155. "\n",
  156. "def nomask_fn(env: gym.Env):\n",
  157. " return [1.0] * 7\n",
  158. "\n",
  159. "def main():\n",
  160. " env = \"MiniGrid-LavaGapS6-v0\"\n",
  161. "\n",
  162. " # TODO Change the safety specification\n",
  163. " formula = \"Pmax=? [G !AgentIsOnLava]\"\n",
  164. " value_for_training = 1.0\n",
  165. " shield_comparison = \"absolute\"\n",
  166. " shielding = ShieldingConfig.Training\n",
  167. " \n",
  168. " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
  169. " \n",
  170. "\n",
  171. " env = gym.make(env, render_mode=\"rgb_array\")\n",
  172. " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
  173. " env = RGBImgObsWrapper(env, 8)\n",
  174. " env = ImgObsWrapper(env)\n",
  175. " env = MiniWrapper(env)\n",
  176. "\n",
  177. " \n",
  178. " env.reset()\n",
  179. " Image.fromarray(env.render()).show()\n",
  180. " input(\"\") \n",
  181. " \n",
  182. " shield_handlers = dict()\n",
  183. " if shield_needed(shielding):\n",
  184. " for value in [0.0, 1.0]:\n",
  185. " shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n",
  186. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
  187. " shield_handlers[value] = shield_handler\n",
  188. "\n",
  189. " print(\"Symbolic Description of the Model:\")\n",
  190. " shield_handlers[1.0].print_symbolic_model()\n",
  191. " input(\"\")\n",
  192. "\n",
  193. " if shield_needed(shielding):\n",
  194. " for value in [1.0]:\n",
  195. " create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n",
  196. " print(f\"The shield for shield_value = {value}\")\n",
  197. " input(\"\")\n",
  198. " \n",
  199. " if shielding == ShieldingConfig.Training:\n",
  200. " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
  201. " env = ActionMasker(env, mask_fn)\n",
  202. " print(\"Training with shield:\")\n",
  203. " create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
  204. " elif shielding == ShieldingConfig.Disabled:\n",
  205. " env = ActionMasker(env, nomask_fn)\n",
  206. " else:\n",
  207. " assert(False) \n",
  208. " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
  209. " model.set_logger(logger)\n",
  210. " steps = 20_000\n",
  211. "\n",
  212. " \n",
  213. " model.learn(steps,callback=[InfoCallback()])\n",
  214. "\n",
  215. "\n",
  216. "\n",
  217. "if __name__ == '__main__':\n",
  218. " main()"
  219. ]
  220. },
  221. {
  222. "cell_type": "code",
  223. "execution_count": null,
  224. "metadata": {},
  225. "outputs": [],
  226. "source": []
  227. }
  228. ],
  229. "metadata": {
  230. "kernelspec": {
  231. "display_name": "Python 3 (ipykernel)",
  232. "language": "python",
  233. "name": "python3"
  234. },
  235. "language_info": {
  236. "codemirror_mode": {
  237. "name": "ipython",
  238. "version": 3
  239. },
  240. "file_extension": ".py",
  241. "mimetype": "text/x-python",
  242. "name": "python",
  243. "nbconvert_exporter": "python",
  244. "pygments_lexer": "ipython3",
  245. "version": "3.10.12"
  246. }
  247. },
  248. "nbformat": 4,
  249. "nbformat_minor": 4
  250. }