You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
18 KiB
250 lines
18 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Example usage of Tempestpy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "plaintext"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"pygame 2.6.0 (SDL 2.28.4, Python 3.10.12)\n",
|
|
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2024-09-24 13:10:36.778354: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
|
"2024-09-24 13:10:36.792058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
|
"2024-09-24 13:10:36.796147: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
|
"2024-09-24 13:10:36.806045: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
|
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
|
"2024-09-24 13:10:37.640233: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
|
"error: XDG_RUNTIME_DIR not set in the environment.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sb3_contrib import MaskablePPO\n",
|
|
"from sb3_contrib.common.wrappers import ActionMasker\n",
|
|
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
|
|
"\n",
|
|
"import gymnasium as gym\n",
|
|
"\n",
|
|
"from minigrid.core.actions import Actions\n",
|
|
"from minigrid.core.constants import TILE_PIXELS\n",
|
|
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
|
|
"\n",
|
|
"import tempfile, datetime, shutil\n",
|
|
"\n",
|
|
"import time\n",
|
|
"import os\n",
|
|
"\n",
|
|
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
|
|
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
|
|
"\n",
|
|
"import os, sys\n",
|
|
"from copy import deepcopy\n",
|
|
"\n",
|
|
"from PIL import Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "plaintext"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCADAAMADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDDooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAPN9Qu9XbU9Q+z3V4Yop3BEcrYUbjjgHpxWf/a+p/8AQRu/+/7f4102kf8AIa1v/r4P/oT1cvNIsr7JlhAc/wAacH/6/wCNcVTGKnUcJLQ+nwnDc8ZgoYmjP3nfR7aNrf5djjf7X1P/AKCN3/3/AG/xo/tfU/8AoI3f/f8Ab/GtO88L3EWWtZBMv90/K3+BrEmhlt5Ck0bIw7MMV0060Knws8XF5disG7V4NefT79j6TorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfAxw8ZJPnS++/wB3X5fgfWyryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8AIq5p1hcQapHPq2tGe/dG8u1jYRxKvfanVsepolh4xTfOn99/u6fP8QjXk2lytfdb7+vy/A8e1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/8AQRu/+/7f40av/wAhq/8A+viT/wBCNU6/Qj4Uuf2vqf8A0Ebv/v8At/jR/a+p/wDQRu/+/wC3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/0Ebv8A7/t/jRq//Iav/wDr4k/9CNU6ALn9r6n/ANBG7/7/ALf40f2vqf8A0Ebv/v8At/jVOigD2SiiigAooooAKKKKACiiigDjdI/5DWt/9fB/9Cerl5q9lY5EswLj+BOT/wDW/GuS1K5mh1fUkilZFe4fcFOM4Y/41nVxVMGqlRzk9D6fCcSTweChhqMPeV9Xtq29vn3N+88UXEuVtYxCv94/M3+ArEmmluJC80jOx7sc1HRXTTowp/CjxcXmOKxjvXm35dPu2Pa9ej8O3etJbXs0lhqgUNDeITET7B+h+h/Cq1/p+m6ckc/irXJdR2/6m3cbVb0/dr94+549a628sbXUbdre8t454j1WRcis7S/C2j6PKZrS0HndpJCXZR2AJ6DtxXxFPFRjBJylp00/B7x9NT6ephpSm2ktfX8VtL8DDsNP03UUkn8K65Lp27/XW6Dcq+v7tvun3HHpVnQY/DtprT21lNJf6oVLTXjkykexfoPoPxrT1Twto+sSia7tB53eSMlGYdwSOo7c1o2dja6dbrb2dvHBEOixrgUqmKjKDSlLX0/F7y/AKeGlGabS09fwW0fxPnzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnX3h8WFFFFAFzV/+Q1f/APXxJ/6Eap1c1f8A5DV//wBfEn/oRqnQAUUUUAeyUUUUAFFFFABRRRQAUUUUAeZapqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/wDoI3f/AH/b/GjV/wDkNX//AF8Sf+hGqdAFz+19T/6CN3/3/b/Gj+19T/6CN3/3/b/GqdFAGvqmqahHq96iX90qLO4VVmYADceBzVT+19T/AOgjd/8Af9v8aNX/AOQ1f/8AXxJ/6Eap0AXP7X1P/oI3f/f9v8aP7X1P/oI3f/f9v8ap0UAa+qapqEer3qJf3Sos7hVWZgANx4HNVP7X1P8A6CN3/wB/2/xo1f8A5DV//wBfEn/oRqnQBc/tfU/+gjd/9/2/xo/tfU/+gjd/9/2/xqnRQBr6pqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/wDoI3f/AH/b/GjV/wDkNX//AF8Sf+hGqdAFz+19T/6CN3/3/b/Gj+19T/6CN3/3/b/GqdFAHslFFFABRRRQAUUUUAFFFFAHk+r/APIav/8Ar4k/9CNU6uav/wAhq/8A+viT/wBCNU6ACiiigC5q/wDyGr//AK+JP/QjVOrmr/8AIav/APr4k/8AQjVOgAooooAuav8A8hq//wCviT/0I1Tq5q//ACGr/wD6+JP/AEI1ToAKKKKALmr/APIav/8Ar4k/9CNU6uav/wAhq/8A+viT/wBCNU6ACiiigD2SiiigAooooAKKKKACiiigDzLVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/wBBG7/7/t/jRq//ACGr/wD6+JP/AEI1ToAuf2vqf/QRu/8Av+3+NH9r6n/0Ebv/AL/t/jVOigD6XorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfnscPGST50vvv93X5fgfdSryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8irmnWFxBqkc+ra0Z790by7WNhHEq99qdWx6miWHjFN86f33+7p8/xCNeTaXK191vv6/L8Dx7VNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6/Qj4Uuf2vqf/QRu/8Av+3+NH9r6n/0Ebv/AL/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/8AQRu/+/7f40av/wAhq/8A+viT/wBCNU6ALn9r6n/0Ebv/AL/t/jR/a+p/9BG7/wC/7f41TooA9kooooAKKKKACiiigAooooA8n1f/AJDV/wD9fEn/AKEap1c1f/kNX/8A18Sf+hGqdABRRRQB7Xr0fh271pLa9mksNUChobxCYifYP0P0P4VWv9P03Tkjn8Va5LqO3/U27jaren7tfvH3PHrXW3lja6jbtb3lvHPEeqyLkVnaX4W0fR5TNaWg87tJIS7KOwBPQduK+Ep4qMYJOUtOmn4PePpqfaVMNKU20lr6/itpfgYdhp+m6ikk/hXXJdO3f663QblX1/dt90+449Ks6DH4dtNae2sppL/VCpaa8cmUj2L9B9B+NaeqeFtH1iUTXdoPO7yRkozDuCR1HbmtGzsbXTrdbezt44Ih0WNcClUxUZQaUpa+n4veX4BTw0ozTaWnr+C2j+J8+av/AMhq/wD+viT/ANCNU6uav/yGr/8A6+JP/QjVOvvD4sKKKKALmr/8hq//AOviT/0I1Tq5q/8AyGr/AP6+JP8A0I1ToAKKKKAPZKKKKACiiigAooooAKKKKAPMtU1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyGr/AP6+JP8A0I1ToAuf2vqf/QRu/wDv+3+NH9r6n/0Ebv8A7/t/jVOigD6XorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfnscPGST50vvv93X5fgfdSryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8irmnWFxBqkc+ra0Z790by7WNhHEq99qdWx6miWHjFN86f33+7p8/wAQjXk2lytfdb7+vy/A8e1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/9BG7/wC/7f40av8A8hq//wCviT/0I1Tr9CPhS5/a+p/9BG7/AO/7f40f2vqf/QRu/wDv+3+NU6KAPpaiiivzQ/IAooooA42iiiv0s/XwooooAKKKKACiiigDyfV/+Q1f/wDXxJ/6Eap1c1f/AJDV/wD9fEn/AKEap0AFFFFAHtevR+HbvWktr2aSw1QKGhvEJiJ9g/Q/Q/hVa/0/TdOSOfxVrkuo7f8AU27jaren7tfvH3PHrXW3lja6jbtb3lvHPEeqyLkVnaX4W0fR5TNaWg87tJIS7KOwBPQduK+Ep4qMYJOUtOmn4PePpqfaVMNKU20lr6/itpfgYdhp+m6ikk/hXXJdO3f663QblX1/dt90+449Ks6DH4dtNae2sppL/VCpaa8cmUj2L9B9B+NaeqeFtH1iUTXdoPO7yRkozDuCR1HbmtGzsbXTrdbezt44Ih0WNcClUxUZQaUpa+n4veX4BTw0ozTaWnr+C2j+J8+av/yGr/8A6+JP/QjVOrmr/wDIav8A/r4k/wDQjVOvvD4sKKKKAPpaiiivzQ/IAooooA42iiiv0s/XwooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA//2Q==",
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAMAAAADACAIAAADdvvtQAAAKN0lEQVR4Ae2d8VUctxaH5Zz8/dIApICkATYNOA3gBoILMG7ADRgKgDQQGrAbCDSQFBBo4LmB937jS5TJzI4kVtxhV/p05pi1dHVX+uk7d3SXRfPq7OwsUFBgVwW+2bUj/VBgUACA4KBKAQCqko/OAAQDVQoAUJV8dAYgGKhSAICq5KMzAMFAlQIAVCUfnQEIBqoUAKAq+egMQDBQpQAAVclHZwCCgSoFAKhKPjoDEAxUKQBAVfLRGYBgoEoBAKqSj84ABANVCgBQlXx0BiAYqFIAgKrkozMAwUCVAgBUJR+dAQgGqhQAoCr56AxAMFClAABVyUdnAIKBKgUAqEo+On976BJcHV27TuHtA+cnpQReA6Dra8c1vvqQmt6ztLmOXwd8eft/FhGWnHALW1KG+iIFAKhIJoyWFACgJWWoL1IAgIpkwmhJgVUB+i2Ek6WBUH+YCqwK0GkItyH8HsK7wxSLUc8VWBUge/tNCBch/C+Ej88XkO7/G27v57Ob1sim0EwOKSUKvABAcVjnfwckRabK8v5z+OnXsLle5EPcqFU2hWZySClRYI0PEtPjUEDSpfBxE8JlCA9p64XWd5tw/yXcPQx8nByFi9dhc/xoKnTOPw9NKkf/Gf7NmsmDHFJKFHh5gGyUWm4FJF3CSDe4u5Kxj2yEy+3ZEH6MFcPofBMubv9BR0yoRkWVl1/r52YT+EbvwMvtCuwLQHF0up3pUkBSNBJMTwpIE4zeqP/XqBPRsXcRRsaWYWRmoBOX4Ekv9g4gG70CkuKQXcLgSQEpYiQ+To4fo85clIjR3f1ww4q3vLklNQkF9hSgOGK7ryn5F0aKSeVFQJQwMdzU2O6Uyzqz/GZWs48VWmJFo7++Zv5H2wZ4/ikcXwybm0SxRCyRqVlfOZErOaSUKLDvEahkDmbz8CW8/zzsjic7HrXGzbVZzjM1q7fNtfxQyhU4DICyt7CLn8Ppj48p2BijMTq2TZY040zNEv4xOuymy+mR5b4DpDtX4SY67p2ND2Gky8qEiUnCH/WamMV6XiQU2FOAdkvjNc8JRktMFJolhKPJFNg7gBRvFHWelLfP19L4mNdPagrNJr3471iBfcnCFHLEjT7+ebMTPcNe51PI/gZUex1d6SInciWHlBIFXj4CaUHtQ+eS4S7ZKPm6+TNc3oV3J0MWdvzd1HC8Td6aqamD0FGTnKgoF1N8omQVeEmAyjfI2Wl8fD38olRrb9cYozE62hKp6Jep40xNNWN09F/rPphScgq8AEDZnDw35i3tCjnK5BV7LIQYRqc/hNuHIZaojHfTMbc3jDZHQ/SyMibvsYofSQVWBehZNsiJ6UwwMizG6FjfSQpmZqCTEDbRtCpA2iCvUMYY6dPFpa1MxOjmj+3bphWG2sBbrArQmnoZRtl3FEZLhGX7YiAF9iWNZzEOVAEAOtCF25dhA9C+rMSBjgOADnTh9mXYr3S8yL6MhXEcoAJrZGHe59/gPwGed4DgFpYQn6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCB38+0NXRdWJ69U1vHzg/KaXiwZ8PdPUhNb1naTv084eeRYQlJ9zClpShvkgBACqSCaMlBQBoSRnqixQAoCKZMFpSAICWlKG+SAEAKpIJoyUFGgFIT4wreUilbArN5JBSokAjAOnRcT/9GjbXi3yIG7XKptBMDiklCqzxQWLJOCpt9KzC+y/DsyzFx+QJc0LHniSvt9BjMVWyZvIgh5QSBRoBKD49zlgxjM43wzOahYuK0BETqlGxp6hGjMZmE/gGa0pSgUYAsjlOMHqjR2z+Gx0zEzEGjR6wKozMDHRMnKf+2xRANvmIkfg4OX6MOnNdIkZ390Nw4rmFc4lKahoEKGJUwoQwCmx3SkhZsGkkCzv/FI4vhs1NolgilsjUrK+cyJUcUkoUaCcC6fnw7z8Pz42Pm+U4/3Eipsp5pmaWtrm258zHvrxIK9BIBLr4Ofz+y5DAG0YxGsWPf7RZVqtszMxSsBiNLOqIP3U3MzmklCjQTgSKe2fL5EWDLiuTDOv2bPi8MSb8UaaJWaznRUKBdgCySU4wWmKi0CwhHE2mQGsAjTHKrrFhlDXDIKFAI3ug4Zb0KWR/A6q9jq50kRO5kkNKiQKNRCAlXzd/hsu78O5kyMKOv5vOfZxhbc3U1EHoqElOVLSbVnyiZBVoJAJ9fD2go6Ll//7yX9FokmFpVzTJ1NTLoo46Gj1yJYeUEgUaiUAKOUq8FXsshIgDXac/hNuHAReV8W46pmBK02S/ORqil5WlAPbYzI+ZAo0AZPOaYGRYjNExs0kKZmagM2OjqKIpgOYYnf64uJWJGN38sX3bVKRf90YNAjTGKLu+wojNclalhEEjm+jEDGlyVQCAXOVt3zkAtb/GrjMEIFd523e+xib67OzMU0jf84E0ctfxX19dhytPed56Og9hDYBcz9e5+uArkLy7jt+XHndtArcwf42bfgcAanp5/ScHQP4aN/0OANT08vpPDoD8NW76HQCo6eX1n1wjAOkbYSVfQpVNoZkcUkoUaAQgfTVMfy4Y/85rPnNxw/lAc1nqa9b4ILF+lFkP+i4i5wNlVfIwaASg+O2w+OeC+iKineGiP0JV4XwgD3rksxGATJ0JRpwP5ATN2G1TAE0w0hfmOR9ovNgerxsEKGJU8l1VzgeqpKqRLIzzgSo52Ll7OxHI/lxQty1lZENcGRXl8La5tjrOBxppU/uykQjE+UC1IOzav50INEnB9NGiLiuTvy3kfKBdadnSrx2AbHITjCboRAEKzaI9L5YUaA0gm6fxsTTnWF9oFu15MVegkT3QsE3mfKD58vrXNBKBlHxxPpA/LVveoZEI9JHzgbYs7hpVjUSgycEunA+0Bjtf36MRgEyvCUacD7QCRk0BNMeI84G8GWoQoDFGWfmUyeui7KxAI5vonedPx0oFAKhSwN67A1DvBFTOH4AqBey9+yvXw296V7eD+ROBOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9p/h/BVTY14Za+YMAAAAASUVORK5CYII=",
|
|
"text/plain": [
|
|
"<PIL.Image.Image image mode=RGB size=192x192>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdin",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" \n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"Computing new shield\n",
|
|
"LOG: Starting with explicit model creation...\n",
|
|
"Elapsed time is 0.003572702407836914 seconds.\n",
|
|
"LOG: Starting with model checking...\n",
|
|
"Elapsed time is 0.0002765655517578125 seconds.\n",
|
|
"LOG: Starting to translate shield...\n",
|
|
"Write to file shielding_files_20240924T131106_4c2n0rfg/shield.\n",
|
|
"Elapsed time is 0.0036995410919189453 seconds.\n",
|
|
"\n",
|
|
"\n",
|
|
"Computing new shield\n",
|
|
"LOG: Starting with explicit model creation...\n",
|
|
"Elapsed time is 0.001905202865600586 seconds.\n",
|
|
"LOG: Starting with model checking...\n",
|
|
"Elapsed time is 0.0001571178436279297 seconds.\n",
|
|
"LOG: Starting to translate shield...\n",
|
|
"Elapsed time is 0.0028374195098876953 seconds.\n",
|
|
"Symbolic Description of the Model:\n",
|
|
"Write to file shielding_files_20240924T131106_mkizoz41/shield.\n",
|
|
"mdp\n",
|
|
"\n",
|
|
"formula AgentCannotMoveEastWall = (colAgent=4&rowAgent=1) | (colAgent=4&rowAgent=2) | (colAgent=4&rowAgent=3) | (colAgent=4&rowAgent=4);\n",
|
|
"formula AgentCannotMoveNorthWall = (colAgent=3&rowAgent=1) | (colAgent=4&rowAgent=1) | (colAgent=1&rowAgent=1) | (colAgent=2&rowAgent=1);\n",
|
|
"formula AgentCannotMoveSouthWall = (colAgent=1&rowAgent=4) | (colAgent=3&rowAgent=4) | (colAgent=4&rowAgent=4) | (colAgent=2&rowAgent=4);\n",
|
|
"formula AgentCannotMoveWestWall = (colAgent=1&rowAgent=2) | (colAgent=1&rowAgent=3) | (colAgent=1&rowAgent=4) | (colAgent=1&rowAgent=1);\n",
|
|
"formula AgentIsOnSlippery = false;\n",
|
|
"formula AgentIsOnLava = (colAgent=2&rowAgent=1) | (colAgent=2&rowAgent=3) | (colAgent=2&rowAgent=4);\n",
|
|
"formula AgentIsOnGoal = (colAgent=4&rowAgent=4);\n",
|
|
"init\n",
|
|
" true\n",
|
|
"endinit\n",
|
|
"\n",
|
|
"\n",
|
|
"module Agent\n",
|
|
" colAgent : [1..4];\n",
|
|
" rowAgent : [1..4];\n",
|
|
" viewAgent : [0..3];\n",
|
|
"\n",
|
|
" [Agent_turn_right] !AgentIsOnLava &true -> 1.000000: (viewAgent'=mod(viewAgent+1,4));\n",
|
|
" [Agent_turn_left] !AgentIsOnLava &viewAgent>0 -> 1.000000: (viewAgent'=viewAgent-1);\n",
|
|
" [Agent_turn_left] !AgentIsOnLava &viewAgent=0 -> 1.000000: (viewAgent'=3);\n",
|
|
" [Agent_move_North] viewAgent=3 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveNorthWall -> 1.000000: (rowAgent'=rowAgent-1);\n",
|
|
" [Agent_move_East] viewAgent=0 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveEastWall -> 1.000000: (colAgent'=colAgent+1);\n",
|
|
" [Agent_move_South] viewAgent=1 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveSouthWall -> 1.000000: (rowAgent'=rowAgent+1);\n",
|
|
" [Agent_move_West] viewAgent=2 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveWestWall -> 1.000000: (colAgent'=colAgent-1);\n",
|
|
"endmodule\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n",
|
|
"\n",
|
|
"def mask_fn(env: gym.Env):\n",
|
|
" return env.create_action_mask()\n",
|
|
"\n",
|
|
"def nomask_fn(env: gym.Env):\n",
|
|
" return [1.0] * 7\n",
|
|
"\n",
|
|
"def main():\n",
|
|
" env = \"MiniGrid-LavaGapS6-v0\"\n",
|
|
"\n",
|
|
" # TODO Change the safety specification\n",
|
|
" formula = \"Pmax=? [G !AgentIsOnLava]\"\n",
|
|
" value_for_training = 1.0\n",
|
|
" shield_comparison = \"absolute\"\n",
|
|
" shielding = ShieldingConfig.Training\n",
|
|
" \n",
|
|
" logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
|
|
" \n",
|
|
"\n",
|
|
" env = gym.make(env, render_mode=\"rgb_array\")\n",
|
|
" image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
|
|
" env = RGBImgObsWrapper(env, 8)\n",
|
|
" env = ImgObsWrapper(env)\n",
|
|
" env = MiniWrapper(env)\n",
|
|
"\n",
|
|
" \n",
|
|
" env.reset()\n",
|
|
" Image.fromarray(env.render()).show()\n",
|
|
" input(\"\") \n",
|
|
" \n",
|
|
" shield_handlers = dict()\n",
|
|
" if shield_needed(shielding):\n",
|
|
" for value in [0.0, 1.0]:\n",
|
|
" shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n",
|
|
" env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
|
|
" shield_handlers[value] = shield_handler\n",
|
|
"\n",
|
|
" print(\"Symbolic Description of the Model:\")\n",
|
|
" shield_handlers[1.0].print_symbolic_model()\n",
|
|
" input(\"\")\n",
|
|
"\n",
|
|
" if shield_needed(shielding):\n",
|
|
" for value in [1.0]:\n",
|
|
" create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n",
|
|
" print(f\"The shield for shield_value = {value}\")\n",
|
|
" input(\"\")\n",
|
|
" \n",
|
|
" if shielding == ShieldingConfig.Training:\n",
|
|
" env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
|
|
" env = ActionMasker(env, mask_fn)\n",
|
|
" print(\"Training with shield:\")\n",
|
|
" create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
|
|
" elif shielding == ShieldingConfig.Disabled:\n",
|
|
" env = ActionMasker(env, nomask_fn)\n",
|
|
" else:\n",
|
|
" assert(False) \n",
|
|
" model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
|
|
" model.set_logger(logger)\n",
|
|
" steps = 20_000\n",
|
|
"\n",
|
|
" \n",
|
|
" model.learn(steps,callback=[InfoCallback()])\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"if __name__ == '__main__':\n",
|
|
" main()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|