|
|
{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Example usage of Tempestpy" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pygame 2.6.0 (SDL 2.28.4, Python 3.10.12)\n", "Hello from the pygame community. https://www.pygame.org/contribute.html\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-09-24 13:10:36.778354: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-09-24 13:10:36.792058: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-09-24 13:10:36.796147: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", "2024-09-24 13:10:36.806045: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2024-09-24 13:10:37.640233: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", "error: XDG_RUNTIME_DIR not set in the environment.\n" ] } ], "source": [ "from sb3_contrib import MaskablePPO\n", "from sb3_contrib.common.wrappers import ActionMasker\n", "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", "\n", "import gymnasium as gym\n", "\n", "from minigrid.core.actions import Actions\n", "from minigrid.core.constants import TILE_PIXELS\n", "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", "\n", "import tempfile, datetime, shutil\n", "\n", "import time\n", "import os\n", "\n", "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", "\n", "import os, sys\n", "from copy import deepcopy\n", "\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [ { "data": { "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCADAAMADASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwDDooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKAPN9Qu9XbU9Q+z3V4Yop3BEcrYUbjjgHpxWf/a+p/8AQRu/+/7f4102kf8AIa1v/r4P/oT1cvNIsr7JlhAc/wAacH/6/wCNcVTGKnUcJLQ+nwnDc8ZgoYmjP3nfR7aNrf5djjf7X1P/AKCN3/3/AG/xo/tfU/8AoI3f/f8Ab/GtO88L3EWWtZBMv90/K3+BrEmhlt5Ck0bIw7MMV0060Knws8XF5disG7V4NefT79j6TorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfAxw8ZJPnS++/wB3X5fgfWyryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8AIq5p1hcQapHPq2tGe/dG8u1jYRxKvfanVsepolh4xTfOn99/u6fP8QjXk2lytfdb7+vy/A8e1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/8AQRu/+/7f40av/wAhq/8A+viT/wBCNU6/Qj4Uuf2vqf8A0Ebv/v8At/jR/a+p/wDQRu/+/wC3+NU6KANfVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/0Ebv8A7/t/jRq//Iav/wDr4k/9CNU6ALn9r6n/ANBG7/7/ALf40f2vqf8A0Ebv/v8At/jVOigD2SiiigAooooAKKKKACiiigDjdI/5DWt/9fB/9Cerl5q9lY5EswLj+BOT/wDW/GuS1K5mh1fUkilZFe4fcFOM4Y/41nVxVMGqlRzk9D6fCcSTweChhqMPeV9Xtq29vn3N+88UXEuVtYxCv94/M3+ArEmmluJC80jOx7sc1HRXTTowp/CjxcXmOKxjvXm35dPu2Pa9ej8O3etJbXs0lhqgUNDeITET7B+h+h/Cq1/p+m6ckc/irXJdR2/6m3cbVb0/dr94+549a628sbXUbdre8t454j1WRcis7S/C2j6PKZrS0HndpJCXZR2AJ6DtxXxFPFRjBJylp00/B7x9NT6ephpSm2ktfX8VtL8DDsNP03UUkn8K65Lp27/XW6Dcq+v7tvun3HHpVnQY/DtprT21lNJf6oVLTXjkykexfoPoPxrT1Twto+sSia7tB53eSMlGYdwSOo7c1o2dja6dbrb2dvHBEOixrgUqmKjKDSlLX0/F7y/AKeGlGabS09fwW0fxPnzV/wDkNX//AF8Sf+hGqdXNX/5DV/8A9fEn/oRqnX3h8WFFFFAFzV/+Q1f/APXxJ/6Eap1c1f8A5DV//wBfEn/oRqnQAUUUUAeyUUUUAFFFFABRRRQAUUUUAeZapqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/wDoI3f/AH/b/GjV/wDkNX//AF8Sf+hGqdAFz+19T/6CN3/3/b/Gj+19T/6CN3/3/b/GqdFAGvqmqahHq96iX90qLO4VVmYADceBzVT+19T/AOgjd/8Af9v8aNX/AOQ1f/8AXxJ/6Eap0AXP7X1P/oI3f/f9v8aP7X1P/oI3f/f9v8ap0UAa+qapqEer3qJf3Sos7hVWZgANx4HNVP7X1P8A6CN3/wB/2/xo1f8A5DV//wBfEn/oRqnQBc/tfU/+gjd/9/2/xo/tfU/+gjd/9/2/xqnRQBr6pqmoR6veol/dKizuFVZmAA3Hgc1U/tfU/wDoI3f/AH/b/GjV/wDkNX//AF8Sf+hGqdAFz+19T/6CN3/3/b/Gj+19T/6CN3/3/b/GqdFAHslFFFABRRRQAUUUUAFFFFAHk+r/APIav/8Ar4k/9CNU6uav/wAhq/8A+viT/wBCNU6ACiiigC5q/wDyGr//AK+JP/QjVOrmr/8AIav/APr4k/8AQjVOgAooooAuav8A8hq//wCviT/0I1Tq5q//ACGr/wD6+JP/AEI1ToAKKKKALmr/APIav/8Ar4k/9CNU6uav/wAhq/8A+viT/wBCNU6ACiiigD2SiiigAooooAKKKKACiiigDzLVNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/wBBG7/7/t/jRq//ACGr/wD6+JP/AEI1ToAuf2vqf/QRu/8Av+3+NH9r6n/0Ebv/AL/t/jVOigD6XorA1GwuJ9Ukn0nWjBfoi+ZayMJImXtuTqufUVTltPEWpox1a/g0mxQfvBZv88gHU7z90f5NfnscPGST50vvv93X5fgfdSryTa5W/ut9/T5/idXRXKRWniLTEU6TfwatYuP3YvH+eMHod4+8P8irmnWFxBqkc+ra0Z790by7WNhHEq99qdWx6miWHjFN86f33+7p8/xCNeTaXK191vv6/L8Dx7VNU1CPV71Ev7pUWdwqrMwAG48Dmqn9r6n/ANBG7/7/ALf40av/AMhq/wD+viT/ANCNU6/Qj4Uuf2vqf/QRu/8Av+3+NH9r6n/0Ebv/AL/t/jVOigDX1TVNQj1e9RL+6VFncKqzMABuPA5qp/a+p/8AQRu/+/7f40av/wAhq/8A+viT/wBCNU6ALn9r6n/0Ebv/AL/t/jR/a+p/9BG7/wC/7f41TooA9kooooAKKKKACiiigAooooA8n1f/AJDV/wD9fEn/AKEap1c1f/kNX/8A18Sf+hGqdABRRRQB7Xr0fh271pLa9mksNUChobxCYifYP0P0P4VWv9P03Tkjn8Va5LqO3/U27jaren7tfvH3PHrXW3lja6jbtb3lvHPEeqyLkVnaX4W0fR5TNaWg87tJIS7KOwBPQduK+Ep4qMYJOUtOmn4PePpqfaVMNKU20lr6/itpfgYdhp+m6ikk/hXXJdO3f663QblX1/dt90+449Ks6DH4dtNae2sppL/VCpaa8cmUj2L9B9B+NaeqeFtH1iUTXdoPO7yRkozDuCR1HbmtGzsbXTrdbezt44Ih0WNcClUxUZQaUpa+n4veX4BTw0ozTaWnr+C2j+J8+av/AMhq/wD+viT/ANCNU6uav/yGr/8A6+JP/QjVOvvD4sKKKKALmr/8hq//AOviT/0I1Tq5q/8AyGr/AP6+JP8A0I1ToAKKKKAPZKKKKACiiigAooooAKKKKAPMtU1TUI9XvUS/ulRZ3CqszAAbjwOaqf2vqf8A0Ebv/v8At/jRq/8AyG "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMAAAADACAIAAADdvvtQAAAKN0lEQVR4Ae2d8VUctxaH5Zz8/dIApICkATYNOA3gBoILMG7ADRgKgDQQGrAbCDSQFBBo4LmB937jS5TJzI4kVtxhV/p05pi1dHVX+uk7d3SXRfPq7OwsUFBgVwW+2bUj/VBgUACA4KBKAQCqko/OAAQDVQoAUJV8dAYgGKhSAICq5KMzAMFAlQIAVCUfnQEIBqoUAKAq+egMQDBQpQAAVclHZwCCgSoFAKhKPjoDEAxUKQBAVfLRGYBgoEoBAKqSj84ABANVCgBQlXx0BiAYqFIAgKrkozMAwUCVAgBUJR+dAQgGqhQAoCr56AxAMFClAABVyUdnAIKBKgUAqEo+On976BJcHV27TuHtA+cnpQReA6Dra8c1vvqQmt6ztLmOXwd8eft/FhGWnHALW1KG+iIFAKhIJoyWFACgJWWoL1IAgIpkwmhJgVUB+i2Ek6WBUH+YCqwK0GkItyH8HsK7wxSLUc8VWBUge/tNCBch/C+Ej88XkO7/G27v57Ob1sim0EwOKSUKvABAcVjnfwckRabK8v5z+OnXsLle5EPcqFU2hWZySClRYI0PEtPjUEDSpfBxE8JlCA9p64XWd5tw/yXcPQx8nByFi9dhc/xoKnTOPw9NKkf/Gf7NmsmDHFJKFHh5gGyUWm4FJF3CSDe4u5Kxj2yEy+3ZEH6MFcPofBMubv9BR0yoRkWVl1/r52YT+EbvwMvtCuwLQHF0up3pUkBSNBJMTwpIE4zeqP/XqBPRsXcRRsaWYWRmoBOX4Ekv9g4gG70CkuKQXcLgSQEpYiQ+To4fo85clIjR3f1ww4q3vLklNQkF9hSgOGK7ryn5F0aKSeVFQJQwMdzU2O6Uyzqz/GZWs48VWmJFo7++Zv5H2wZ4/ikcXwybm0SxRCyRqVlfOZErOaSUKLDvEahkDmbz8CW8/zzsjic7HrXGzbVZzjM1q7fNtfxQyhU4DICyt7CLn8Ppj48p2BijMTq2TZY040zNEv4xOuymy+mR5b4DpDtX4SY67p2ND2Gky8qEiUnCH/WamMV6XiQU2FOAdkvjNc8JRktMFJolhKPJFNg7gBRvFHWelLfP19L4mNdPagrNJr3471iBfcnCFHLEjT7+ebMTPcNe51PI/gZUex1d6SInciWHlBIFXj4CaUHtQ+eS4S7ZKPm6+TNc3oV3J0MWdvzd1HC8Td6aqamD0FGTnKgoF1N8omQVeEmAyjfI2Wl8fD38olRrb9cYozE62hKp6Jep40xNNWN09F/rPphScgq8AEDZnDw35i3tCjnK5BV7LIQYRqc/hNuHIZaojHfTMbc3jDZHQ/SyMibvsYofSQVWBehZNsiJ6UwwMizG6FjfSQpmZqCTEDbRtCpA2iCvUMYY6dPFpa1MxOjmj+3bphWG2sBbrArQmnoZRtl3FEZLhGX7YiAF9iWNZzEOVAEAOtCF25dhA9C+rMSBjgOADnTh9mXYr3S8yL6MhXEcoAJrZGHe59/gPwGed4DgFpYQn6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCAJQQh6a8AgCU1wiLhAIAlBCHprwCAJTXCIuEAgCUEIemvAIAlNcIi4QCB38+0NXRdWJ69U1vHzg/KaXiwZ8PdPUhNb1naTv084eeRYQlJ9zClpShvkgBACqSCaMlBQBoSRnqixQAoCKZMFpSAICWlKG+SAEAKpIJoyUFGgFIT4wreUilbArN5JBSokAjAOnRcT/9GjbXi3yIG7XKptBMDiklCqzxQWLJOCpt9KzC+y/DsyzFx+QJc0LHniSvt9BjMVWyZvIgh5QSBRoBKD49zlgxjM43wzOahYuK0BETqlGxp6hGjMZmE/gGa0pSgUYAsjlOMHqjR2z+Gx0zEzEGjR6wKozMDHRMnKf+2xRANvmIkfg4OX6MOnNdIkZ390Nw4rmFc4lKahoEKGJUwoQwCmx3SkhZsGkkCzv/FI4vhs1NolgilsjUrK+cyJUcUkoUaCcC6fnw7z8Pz42Pm+U4/3Eipsp5pmaWtrm258zHvrxIK9BIBLr4Ofz+y5DAG0YxGsWPf7RZVqtszMxSsBiNLOqIP3U3MzmklCjQTgSKe2fL5EWDLiuTDOv2bPi8MSb8UaaJWaznRUKBdgCySU4wWmKi0CwhHE2mQGsAjTHKrrFhlDXDIKFAI3ug4Zb0KWR/A6q9jq50kRO5kkNKiQKNRCAlXzd/hsu78O5kyMKOv5vOfZxhbc3U1EHoqElOVLSbVnyiZBVoJAJ9fD2go6Ll//7yX9FokmFpVzTJ1NTLoo46Gj1yJYeUEgUaiUAKOUq8FXsshIgDXac/hNuHAReV8W46pmBK02S/ORqil5WlAPbYzI+ZAo0AZPOaYGRYjNExs0kKZmagM2OjqKIpgOYYnf64uJWJGN38sX3bVKRf90YNAjTGKLu+wojNclalhEEjm+jEDGlyVQCAXOVt3zkAtb/GrjMEIFd523e+xib67OzMU0jf84E0ctfxX19dhytPed56Og9hDYBcz9e5+uArkLy7jt+XHndtArcwf42bfgcAanp5/ScHQP4aN/0OANT08vpPDoD8NW76HQCo6eX1n1wjAOkbYSVfQpVNoZkcUkoUaAQgfTVMfy4Y/85rPnNxw/lAc1nqa9b4ILF+lFkP+i4i5wNlVfIwaASg+O2w+OeC+iKineGiP0JV4XwgD3rksxGATJ0JRpwP5ATN2G1TAE0w0hfmOR9ovNgerxsEKGJU8l1VzgeqpKqRLIzzgSo52Ll7OxHI/lxQty1lZENcGRXl8La5tjrOBxppU/uykQjE+UC1IOzav50INEnB9NGiLiuTvy3kfKBdadnSrx2AbHITjCboRAEKzaI9L5YUaA0gm6fxsTTnWF9oFu15MVegkT3QsE3mfKD58vrXNBKBlHxxPpA/LVveoZEI9JHzgbYs7hpVjUSgycEunA+0Bjtf36MRgEyvCUacD7QCRk0BNMeI84G8GWoQoDFGWfmUyeui7KxAI5vonedPx0oFAKhSwN67A1DvBFTOH4AqBey9+yvXw296V7eD+ROBOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9pwhAnup24BuAOlhkzykCkKe6HfgGoA4W2XOKAOSpbge+AaiDRfacIgB5qtuBbwDqYJE9p/h/BVTY14Za+YMAAAAASUVORK5CYII=", "text/plain": [ "<PIL.Image.Image image mode=RGB size=192x192>" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdin", "output_type": "stream", "text": [ " \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "Computing new shield\n", "LOG: Starting with explicit model creation...\n", "Elapsed time is 0.003572702407836914 seconds.\n", "LOG: Starting with model checking...\n", "Elapsed time is 0.0002765655517578125 seconds.\n", "LOG: Starting to translate shield...\n", "Write to file shielding_files_20240924T131106_4c2n0rfg/shield.\n", "Elapsed time is 0.0036995410919189453 seconds.\n", "\n", "\n", "Computing new shield\n", "LOG: Starting with explicit model creation...\n", "Elapsed time is 0.001905202865600586 seconds.\n", "LOG: Starting with model checking...\n", "Elapsed time is 0.0001571178436279297 seconds.\n", "LOG: Starting to translate shield...\n", "Elapsed time is 0.0028374195098876953 seconds.\n", "Symbolic Description of the Model:\n", "Write to file shielding_files_20240924T131106_mkizoz41/shield.\n", "mdp\n", "\n", "formula AgentCannotMoveEastWall = (colAgent=4&rowAgent=1) | (colAgent=4&rowAgent=2) | (colAgent=4&rowAgent=3) | (colAgent=4&rowAgent=4);\n", "formula AgentCannotMoveNorthWall = (colAgent=3&rowAgent=1) | (colAgent=4&rowAgent=1) | (colAgent=1&rowAgent=1) | (colAgent=2&rowAgent=1);\n", "formula AgentCannotMoveSouthWall = (colAgent=1&rowAgent=4) | (colAgent=3&rowAgent=4) | (colAgent=4&rowAgent=4) | (colAgent=2&rowAgent=4);\n", "formula AgentCannotMoveWestWall = (colAgent=1&rowAgent=2) | (colAgent=1&rowAgent=3) | (colAgent=1&rowAgent=4) | (colAgent=1&rowAgent=1);\n", "formula AgentIsOnSlippery = false;\n", "formula AgentIsOnLava = (colAgent=2&rowAgent=1) | (colAgent=2&rowAgent=3) | (colAgent=2&rowAgent=4);\n", "formula AgentIsOnGoal = (colAgent=4&rowAgent=4);\n", "init\n", " true\n", "endinit\n", "\n", "\n", "module Agent\n", " colAgent : [1..4];\n", " rowAgent : [1..4];\n", " viewAgent : [0..3];\n", "\n", " [Agent_turn_right] !AgentIsOnLava &true -> 1.000000: (viewAgent'=mod(viewAgent+1,4));\n", " [Agent_turn_left] !AgentIsOnLava &viewAgent>0 -> 1.000000: (viewAgent'=viewAgent-1);\n", " [Agent_turn_left] !AgentIsOnLava &viewAgent=0 -> 1.000000: (viewAgent'=3);\n", " [Agent_move_North] viewAgent=3 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveNorthWall -> 1.000000: (rowAgent'=rowAgent-1);\n", " [Agent_move_East] viewAgent=0 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveEastWall -> 1.000000: (colAgent'=colAgent+1);\n", " [Agent_move_South] viewAgent=1 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveSouthWall -> 1.000000: (rowAgent'=rowAgent+1);\n", " [Agent_move_West] viewAgent=2 & !AgentIsOnLava & !AgentIsOnGoal & !AgentCannotMoveWestWall -> 1.000000: (colAgent'=colAgent-1);\n", "endmodule\n", "\n", "\n" ] } ], "source": [ "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", "\n", "def mask_fn(env: gym.Env):\n", " return env.create_action_mask()\n", "\n", "def nomask_fn(env: gym.Env):\n", " return [1.0] * 7\n", "\n", "def main():\n", " env = \"MiniGrid-LavaGapS6-v0\"\n", "\n", " # TODO Change the safety specification\n", " formula = \"Pmax=? [G !AgentIsOnLava]\"\n", " value_for_training = 1.0\n", " shield_comparison = \"absolute\"\n", " shielding = ShieldingConfig.Training\n", " \n", " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", " \n", "\n", " env = gym.make(env, render_mode=\"rgb_array\")\n", " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", " env = RGBImgObsWrapper(env, 8)\n", " env = ImgObsWrapper(env)\n", " env = MiniWrapper(env)\n", "\n", " \n", " env.reset()\n", " Image.fromarray(env.render()).show()\n", " input(\"\") \n", " \n", " shield_handlers = dict()\n", " if shield_needed(shielding):\n", " for value in [0.0, 1.0]:\n", " shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n", " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", " shield_handlers[value] = shield_handler\n", "\n", " print(\"Symbolic Description of the Model:\")\n", " shield_handlers[1.0].print_symbolic_model()\n", " input(\"\")\n", "\n", " if shield_needed(shielding):\n", " for value in [1.0]:\n", " create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n", " print(f\"The shield for shield_value = {value}\")\n", " input(\"\")\n", " \n", " if shielding == ShieldingConfig.Training:\n", " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", " env = ActionMasker(env, mask_fn)\n", " print(\"Training with shield:\")\n", " create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", " elif shielding == ShieldingConfig.Disabled:\n", " env = ActionMasker(env, nomask_fn)\n", " else:\n", " assert(False) \n", " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", " model.set_logger(logger)\n", " steps = 20_000\n", "\n", " \n", " model.learn(steps,callback=[InfoCallback()])\n", "\n", "\n", "\n", "if __name__ == '__main__':\n", " main()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }
|