The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.
 
 
 
 
 
 

719 lines
324 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example usage of Tempestpy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.6.0 (SDL 2.28.4, Python 3.10.12)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-09-24 13:00:24.842265: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-09-24 13:00:24.857294: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-09-24 13:00:24.861712: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-09-24 13:00:24.871856: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-09-24 13:00:25.712381: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
"error: XDG_RUNTIME_DIR not set in the environment.\n"
]
}
],
"source": [
"from sb3_contrib import MaskablePPO\n",
"from sb3_contrib.common.wrappers import ActionMasker\n",
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
"\n",
"import gymnasium as gym\n",
"\n",
"from minigrid.core.actions import Actions\n",
"from minigrid.core.constants import TILE_PIXELS\n",
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
"\n",
"import tempfile, datetime, shutil\n",
"\n",
"import time\n",
"import os\n",
"\n",
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
"\n",
"import os, sys\n",
"from copy import deepcopy\n",
"\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting the training\n"
]
},
{
"data": {
"image/jpeg": "",
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGB size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"Computing new shield\n",
"LOG: Starting with explicit model creation...\n",
"Elapsed time is 0.053414344787597656 seconds.\n",
"LOG: Starting with model checking...\n",
"Elapsed time is 0.0018138885498046875 seconds.\n",
"LOG: Starting to translate shield...\n",
"Elapsed time is 0.08892679214477539 seconds.\n",
"\n",
"\n",
"Computing new shield\n",
"LOG: Starting with explicit model creation...\n",
"Elapsed time is 0.04616570472717285 seconds.\n",
"LOG: Starting with model checking...\n",
"Elapsed time is 0.0017552375793457031 seconds.\n",
"LOG: Starting to translate shield...\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1211\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1222\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1233\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1244\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1255\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1266\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1277\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1288\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1299\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1344\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1355\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1366\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1377\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1388\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1399\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1410\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1421\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1432\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1460\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1471\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1482\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1493\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1504\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1515\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1526\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1537\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1604\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1615\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1626\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1637\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1648\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1659\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1670\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1681\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4689\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4711\n",
"Elapsed time is 0.0928337574005127 seconds.\n",
"\n",
"\n",
"Computing new shield\n",
"LOG: Starting with explicit model creation...\n",
"Elapsed time is 0.04615354537963867 seconds.\n",
"LOG: Starting with model checking...\n",
"Elapsed time is 0.0014677047729492188 seconds.\n",
"LOG: Starting to translate shield...\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1211\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1222\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1233\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1244\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1255\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1266\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1277\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1288\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1299\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1344\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1355\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1366\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1377\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1388\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1399\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1410\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1421\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1432\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1460\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1471\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1482\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1493\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1504\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1515\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1526\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1537\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1604\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1615\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1626\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1637\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1648\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1659\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1670\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1681\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4689\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4711\n",
"Elapsed time is 0.08211565017700195 seconds.\n",
"\n",
"\n",
"Computing new shield\n",
"LOG: Starting with explicit model creation...\n",
"Elapsed time is 0.045874834060668945 seconds.\n",
"LOG: Starting with model checking...\n",
"Elapsed time is 0.0018587112426757812 seconds.\n",
"LOG: Starting to translate shield...\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1210\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1211\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1221\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1222\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1232\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1233\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1243\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1244\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1254\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1255\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1265\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1266\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1276\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1277\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1287\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1288\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1298\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1299\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1332\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1343\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1344\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1354\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1355\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1365\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1366\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1376\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1377\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1387\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1388\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1398\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1399\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1409\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1410\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1420\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1421\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1432\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1460\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1461\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1471\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1472\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1482\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1483\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1493\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1494\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1504\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1505\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1515\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1516\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1526\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1527\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1537\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1538\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1604\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1615\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1616\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1626\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1627\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1637\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1638\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1648\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1649\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1659\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1660\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1670\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1671\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1681\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1682\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1693\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4689\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4695\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4711\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4718\n",
"Elapsed time is 0.0849301815032959 seconds.\n",
"\n",
"\n",
"Computing new shield\n",
"LOG: Starting with explicit model creation...\n",
"Elapsed time is 0.04583621025085449 seconds.\n",
"LOG: Starting with model checking...\n",
"Elapsed time is 0.0014719963073730469 seconds.\n",
"LOG: Starting to translate shield...\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1205\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1206\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1207\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1208\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1209\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1210\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1211\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1216\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1217\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1218\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1219\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1220\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1221\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1222\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1227\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1228\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1229\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1230\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1231\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1232\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1233\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1238\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1239\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1240\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1241\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1242\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1243\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1244\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1249\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1250\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1251\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1252\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1253\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1254\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1255\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1261\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1262\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1263\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1264\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1265\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1266\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1273\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1274\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1275\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1276\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1277\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1285\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1286\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1287\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1288\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1297\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1298\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1299\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1316\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1317\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1318\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1319\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1320\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1327\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1328\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1329\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1330\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1331\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1332\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1338\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1339\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1340\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1341\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1342\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1343\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1344\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1349\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1350\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1351\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1352\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1353\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1354\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1355\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1360\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1361\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1362\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1363\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1364\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1365\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1366\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1372\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1373\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1374\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1375\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1376\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1377\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1384\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1385\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1386\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1387\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1388\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1396\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1397\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1398\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1399\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1408\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1409\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1410\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1420\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1421\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1432\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1460\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1461\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1462\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1463\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1464\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1471\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1472\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1473\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1474\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1475\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1476\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1482\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1483\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1484\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1485\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1486\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1487\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1488\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1493\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1494\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1495\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1496\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1497\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1498\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1499\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1504\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1505\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1506\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1507\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1508\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1509\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1510\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1515\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1516\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1517\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1518\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1519\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1520\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1521\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1526\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1527\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1528\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1529\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1530\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1531\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1532\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1537\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1538\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1539\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1540\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1541\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1542\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1543\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1604\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1615\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1616\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1626\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1627\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1628\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1637\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1638\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1639\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1640\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1648\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1649\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1650\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1651\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1652\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1659\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1660\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1661\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1662\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1663\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1664\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1670\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1671\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1672\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1673\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1674\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1675\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1676\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1681\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1682\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1683\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1684\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1685\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1686\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1687\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1693\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1694\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1695\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1696\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1697\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1698\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1705\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1706\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1707\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1708\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1709\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1717\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1718\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1719\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1720\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1729\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1730\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 1731\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4689\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4695\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4711\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4718\n",
" WARN (PreShield.cpp:50): No shielding action possible with absolute comparison for state with index 4725\n",
"Elapsed time is 0.08353805541992188 seconds.\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGBA size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shield for shield_value = 0.0\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGBA size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shield for shield_value = 0.9\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGBA size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shield for shield_value = 0.99\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGBA size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shield for shield_value = 0.999\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGBA size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The shield for shield_value = 1.0\n",
"Training with shield:\n"
]
},
{
"data": {
"image/png": "",
"text/plain": [
"<PIL.Image.Image image mode=RGBA size=480x480>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cpu device\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n",
"Wrapping the env in a VecTransposeImage.\n"
]
},
{
"ename": "AssertionError",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 63\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStarting the training\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 63\u001b[0m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[3], line 56\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 53\u001b[0m model\u001b[38;5;241m.\u001b[39mset_logger(logger)\n\u001b[1;32m 54\u001b[0m steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m20_000\u001b[39m\n\u001b[0;32m---> 56\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m(\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 57\u001b[0m model\u001b[38;5;241m.\u001b[39mlearn(steps,callback\u001b[38;5;241m=\u001b[39m[InfoCallback()])\n",
"\u001b[0;31mAssertionError\u001b[0m: "
]
}
],
"source": [
"# Path to the grid-to-PRISM binary, taken from the M2P_BINARY environment variable (None when unset).\n",
"GRID_TO_PRISM_BINARY = os.environ.get(\"M2P_BINARY\")\n",
"\n",
"def mask_fn(env: gym.Env):\n",
"    \"\"\"Action-mask callback for ``ActionMasker``: delegate to ``env.create_action_mask()``.\n",
"\n",
"    ``env`` is expected to be the shielding wrapper that provides ``create_action_mask``.\n",
"    \"\"\"\n",
"    action_mask = env.create_action_mask()\n",
"    return action_mask\n",
"\n",
"def nomask_fn(env: gym.Env):\n",
"    \"\"\"Action-mask callback used when shielding is disabled: every action is allowed.\n",
"\n",
"    The mask length follows the env's discrete action space (MaskablePPO requires\n",
"    it to match) instead of a hard-coded 7.\n",
"    \"\"\"\n",
"    return [1.0] * env.action_space.n\n",
"\n",
"def main():\n",
"    \"\"\"Build the lava MiniGrid env, render its shields, and train MaskablePPO.\n",
"\n",
"    A ``MiniGridShieldHandler`` is built for every value in ``shield_values`` and each\n",
"    resulting shield is drawn as an overlay on ``image_env``. With\n",
"    ``ShieldingConfig.Training`` the env is wrapped once with the handler for\n",
"    ``value_for_training`` and actions are masked via ``mask_fn``; with\n",
"    ``ShieldingConfig.Disabled`` an all-permissive mask (``nomask_fn``) is used.\n",
"\n",
"    Raises:\n",
"        RuntimeError: if a shield is needed but ``M2P_BINARY`` is not set.\n",
"        ValueError: if ``shielding`` is neither ``Training`` nor ``Disabled``.\n",
"    \"\"\"\n",
"    env_id = \"MiniGrid-LavaFaultyS15-1-v0\"\n",
"\n",
"    formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
"    value_for_training = 0.0  # must be one of shield_values below\n",
"    shield_comparison = \"absolute\"\n",
"    shielding = ShieldingConfig.Training\n",
"\n",
"    logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
"\n",
"    env = gym.make(env_id, render_mode=\"rgb_array\")\n",
"    # Separate view of the base env (tile size TILE_PIXELS), used only to draw shield overlays.\n",
"    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
"    env = RGBImgObsWrapper(env, 8)\n",
"    env = ImgObsWrapper(env)\n",
"    env = MiniWrapper(env)\n",
"\n",
"    env.reset()\n",
"    Image.fromarray(env.render()).show()\n",
"\n",
"    shield_values = [0.0, 0.9, 0.99, 0.999, 1.0]\n",
"    shield_handlers = dict()\n",
"    if shield_needed(shielding):\n",
"        if GRID_TO_PRISM_BINARY is None:\n",
"            raise RuntimeError(\"Set the M2P_BINARY environment variable to build shields.\")\n",
"        # Only build one handler per value here; the env itself is wrapped exactly\n",
"        # once, below, with the handler selected for training.\n",
"        for value in shield_values:\n",
"            shield_handlers[value] = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=False, prism_file=None)\n",
"\n",
"        for value in shield_values:\n",
"            # Print the caption first so it appears above its overlay image.\n",
"            print(f\"The shield for shield_value = {value}\")\n",
"            create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n",
"\n",
"    if shielding == ShieldingConfig.Training:\n",
"        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
"        env = ActionMasker(env, mask_fn)\n",
"        print(\"Training with shield:\")\n",
"        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
"    elif shielding == ShieldingConfig.Disabled:\n",
"        env = ActionMasker(env, nomask_fn)\n",
"    else:\n",
"        # Explicit error instead of assert, so the check is not stripped under `python -O`.\n",
"        raise ValueError(f\"Unsupported shielding configuration: {shielding}\")\n",
"\n",
"    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
"    model.set_logger(logger)\n",
"    steps = 20_000\n",
"\n",
"    model.learn(steps, callback=[InfoCallback()])\n",
"\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
"    # In a Jupyter kernel the cell runs in the __main__ module, so this fires on cell execution.\n",
"    print(\"Starting the training\")\n",
"    main()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}