You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
199 lines
28 KiB
199 lines
28 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Example usage of Tempestpy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "plaintext"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)\n",
|
|
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"2024-11-29 12:18:59.459471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
|
"2024-11-29 12:18:59.474489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
|
"2024-11-29 12:18:59.478811: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
|
|
"2024-11-29 12:18:59.488641: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
|
|
"To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
|
"2024-11-29 12:19:00.368388: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
|
|
"error: XDG_RUNTIME_DIR not set in the environment.\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"from sb3_contrib import MaskablePPO\n",
|
|
"from sb3_contrib.common.wrappers import ActionMasker\n",
|
|
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
|
|
"\n",
|
|
"import gymnasium as gym\n",
|
|
"\n",
|
|
"from minigrid.core.actions import Actions\n",
|
|
"from minigrid.core.constants import TILE_PIXELS\n",
|
|
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
|
|
"\n",
|
|
"import tempfile, datetime, shutil\n",
|
|
"\n",
|
|
"import time\n",
|
|
"import os\n",
|
|
"\n",
|
|
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
|
|
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
|
|
"\n",
|
|
"import os, sys\n",
|
|
"from copy import deepcopy\n",
|
|
"\n",
|
|
"from PIL import Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
|
|
"languageId": "plaintext"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Starting the training\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"image/jpeg": "",
|
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWAAAAEgCAIAAAAFWz53AAAR8ElEQVR4Ae2d4XnUSBKGxzz8vksAXwC7CeBN4DYBNoLZADAJkMCZAJiLgAQggTMJ7AWwJoEjgb1PFm6Klrqnx6WR5NKrxw+0uqtKU2/1fJZa1szFfr/fsUEAAhAYI/BsrJM+CEAAAh0BBIJ5AAEIFAkgEEU0DEAAAggEcwACECgSQCCKaBiAAAQQCOYABCBQJIBAFNEwAAEIIBDMAQhAoEgAgSiiYQACEEAgmAMQgECRAAJRRMMABCCAQDAHIACBIgEEooiGAQhAAIFgDkAAAkUCCEQRDQMQgAACwRyAAASKBBCIIhoGIAABBII5AAEIFAkgEEU0DEAAAggEcwACECgSQCCKaBiAAAQQCOYABCBQJIBAFNEwAAEIIBDMAQhAoEgAgSiiYQACEEAgmAMQgECRAAJRRMMABCCAQDAHIACBIgEEooiGAQhAAIFgDkAAAkUCCEQRDQMQgMBzEMQmcHh/OGuC+9/3Z41P8GUJzCEQh8MZ5+h+vyd+bQ69rw1OMgb/CsYZ5mfl6P4hLjH8DIkAgbAEEIiwpSUxCPgJIBB+hkSAQFgCCETY0pIYBPwEEAg/QyJAICwBBCJsaUkMAn4CCISfIREgEJYAAhG2tCQGAT8BBMLPkAgQCEsAgQhbWhKDgJ8AAuFnSAQIhCWAQIQtLYlBwE8AgfAzJAIEwhJAIMKWlsQg4CeAQPgZEgECYQkgEGFLS2IQ8BNAIPwMiQCBsAQQiLClJTEI+AkgEH6GRIBAWAIIRNjSkhgE/AQQCD9DIkAgLAEEImxpSQwCfgIIhJ8hESAQlgACEba0JAYBPwEEws+QCBAISwCBCFtaEoOAnwAC4WdIBAiEJYBAhC0tiUHATwCB8DMkAgTCEkAgwpaWxCDgJ3Chryf3RyECBCAQksDzGbI6HA7nO4oEjvgVvPCpwNFQAD71BJ2jXGI4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJMAAuEEiDsEIhNAICJXl9wg4CSAQDgB4g6ByAQQiMjVJTcIOAkgEE6AuEMgMgEEInJ1yQ0CTgIIhBMg7hCITACBiFxdcoOAkwAC4QSIOwQiE0AgIleX3CDgJIBAOAHiDoHIBBCIyNUlNwg4CSAQToC4QyAyAQQicnXJDQJOAgiEEyDuEIhMAIGIXF1yg4CTAALhBIg7BCITQCAiV5fcIOAkgEA4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJPAhb7+3Bkitvv7F4fYCW48u9+/MP9rU+B5bXCiscPhjO8xCdxZ479/OxEFwqyVwFnnz7nn57l/wXOJsdZpy+uCwAoIIBArKAIvAQJrJYBArLUyvC4IrIAAArGCIvASILBWAgjEWivD64LACgggECsoAi8BAmslgEA0Vebuf7vbu+OWsmk0U0C7Ed/SEMOZ+dij07YEEAhLo9h+82n3y793V4fi+19zWqOyaTRTQLsRv6eRMM7Mx9aCtiUwxx9K2eM90fbrq93d193nL937/+WL3c0/d1eX31LRnL7+1A1pe/G37t+jZoqggHYjvsU4Px9bC9qWwBx/av2k/1Ltr7ffcWWT+Ppqd3P7XRr0JlePNnW+u919+dq1NdetWSYunYXZiD8/H/2p9ZOen+f+S0oEwrxBx5pWIPpx+zZWj84akjTYAFYm1F+f+taR+JbGsD0tHwRiSNj2IBCWxkh7KBC9kaapThNeXn47axjxvO/qTjHuOgVJlyQly6yf+BmQbHcqPghEBjbbZQ0iA9K6qzd8y3u+u+j4cbmh8QDEr4M6N5/60bczyl2Mplpff9xd3nSLC5VNv9N0I6Nyp6P3VRCFUkC7ET/RWIRPOjqNjABnEBmQ4q4WHd986i4rhisOkoZ0I0P+wz
sdfVBN/bR4OTwM8ZflM6wIPSLAGsSRaZDWIKwKpIVJ29kvQypcEou0MGmnfurMDmxDEX94+XYmPqxBZPMw20UgMiD5bhKIfsBO02Q6fM83mqUIqdHo2GiWwqZGo2OjWQqbGo2OjWYpbGo0OjaaKSwCkdiONrjEGMVS7NRvttt99/eU/WnCUBp6z0az4WEaHRvNiF8v05APPRkBBCID0rTbvz+PmjaaDeM0OjaaEX9IgJ5GAtzFaALVnTJ8zJ8gGnpqrUE/9U2PISmUAtqN+InGInzS0WlkBDiDyICM7+ruw4f/7t593r1+2d3FuPx7bmaXIUfvdMhBU19DCqJN9yz0+z9txBeKBfmkQtDICLBImQHJd/tFSjt3ZWFlwkqDliS0pQe30g3Rint/vIoB8YXofHxYpOxnYOlfBKJE5lu/vYuRTdNXP+1uv3x/KCs94pmWMBVCNyyvXnRnH/1mleVbl/mP+PPzQSDMBBxpIhAjUGyXFYi+P3sbl25kWJmQY33q2yMS39IYtqflg0AMCdseBMLSGGkPBaI36qfpq59/WEoY+ksmPvwxvmwxNLY9xLc0hu2p+CAQQ7a2h0VKS+OEttYpb349bq+VSLsYedzhwYL4DyTG/z83n/Gjbq+X25zbqzkZQ6CZAALRjApDCGyPAAKxvZqTMQSaCSAQzagwhMD2CMxxF+NJU33/4vCkXz8vvk5AdzHqBhsfneMuxpP+1OD3bzc+Q+Kn/6Tn57k/1ZpLjPhvADKEwKMJIBCPRocjBOIT2JpA6Gmq+weq4leWDCEwAYGtCcRrPRUxATZCQGAbBLYmEK92O/2wQQACTQQ2JRA6d7i8/+EkomlyYASBTQlEOndIjdYJoGcH9Vzm0U02jWYKaDfiWxpiODMfe3TalsB2BOKl+Q48fRmedk/Y9JU5+jqcyrdmaU5rVDaNZgpoN+L3NBLGmfnYWtC2BOb4Qyl7vOXa2VmDdu8/HLLtBb2+2t197T5LbvitWZrT6Zty9PlR2o6a6TNmFNBuxLcY5+dja0HbEpjjT61X8JdqurU5vELQesQXy2K0bT8wJpvE+mJefWDk8BMos0+RtGalj5/qD038+fnwgTGj0z51bkQgtCp5k3J+aFzvdu8e2sX/rUD0RvZtrJ70NXlZCCsTGqpPfetLfEtj2J6WDwIxJGx7NiIQf97fvLCJq61zin9kXcPdoUD0Npqm+qz6l5c7nSBUtu4U4667oDj1c6WIX6Gqoan4IBB1zltYg9B6pK4mhps6NXTCSoQN0fhZcp18VBXExrRt4lsaw/a5+QyPuM2eLdzF0KVEaasM/eCi78K6vDnyrVn6naYbGZU7HX1EnVMolALajfiJxiJ80tFpZATCn0FoeTK7f2EJaEgGx5cq5aPvwnrzqbus0PVCdlkhaUg3MmQ5vNPRHzJblbCvg/gisCyfrBzs9gTCr0H8a7ernyZo8fJNZTakNQirAmlh0nb2y5AKlcQiLUzaqZ86s4PaUMQfLtmciQ9rENk8zHbDC8SfhQWIxOHIUmUSiN7BTtMUYviebzRLEVKj0bHRLIVNjUbHRrMUNjUaHRvNUtjUaHRsNFNYBCKxHW3EvsTQFYRWIuubDGT2oW6URvWb7XbfLaH3pwlDaegtG81S2NRodGw0S2FTo9Gx0SyFTY1Gx0azFDY1Gh0bzVJYGiUCsQWi8aEsmbUKRM+xn38lpqm/0SzZp0ajY6NZCpsajY6NZilsajQ6NpqlsKnR6NholsLSGBJ4NuyK0qNbmI03GGUm49rWnTJ8zJ8gGjporUE/9U2PISmUAtqN+InGInzS0WlkBAKfQVRuXmQQtCvj2h9E6OaFvqH73efid/DaZcjROx06hqa+hhREm+6J6Pdb2ogvFAvySYWgkREIvEj5V5bqsd2LUYN+kdLOXZnZr+q20qAlCW3DpzMq7v1BKwbEF6Lz8WGRsp+BpX+jCsTowxclCH2/7oaOPJrRC0RvkU3TVz/tbr905wLa7G
plWsJUv25YXr3ozj76zSrLty7zH/Hn54NAmAk40owqEP9pXoBIULR48EvaSQ0rEH1n9ja20pC81LAyod361LeOxLc0hu1p+SAQQ8K2J6RAaMXx2FKhZfC9rdXKfCViKBC9eT9NX/38w1LC90gPLcnEhz+6P77U19WftBG/jmsqPghEnXPIRUpdLDxuk+NvjZ56w9/8etxWK5F2MfK4w4MF8R9IjP9/bj7jR91eb0iBaH2Tb6/cZAyB0wg8O80cawhAYEsEEIgtVZtcIXAiAQTiRGCYQ2BLBOa4i7ElnuQKgVAE5likXMGnWj++Zvv9ntdfwQefChwNzcCn/gKco1xiOAHiDoHIBBCIyNUlNwg4CaxLIPSs0/3jTs6kcIcABKYhsC6B0CNW+mGDAARWQmBdAqFPZTjpUxxWApGXAYGoBFYkEDp3uLz/4SQi6mwjrydHYEUCkc4dUuPJ0eQFQyAYgbUIhJ7QTh8gefwjIoMVgXQgsFYCaxGI7Kwh210rPV4XBIITWIVA6Nbm9Y+ctcv9zh+RsAeBBQisQiBGzxdGOxcgxCEhsGECqxCI0dsWo50brhSpQ2ABAssLhJYndXdzuKlTQ2wQgMCCBJYXiOty9pWhshMjEIDAZAQWFgitRFbWGjTEUuVkpSYQBE4nsLBAHF1oOGpwesp4QAACrQQWFojK6UOfwVGD1kSxgwAETiewpEDozT+6PGmzkAEaYYHQhsCcBJYUiMbLh0azOalxLAhshMBiAmEfvqiz5tGMOh9GIXA+AosJxEkXDicZnw8WkSGwNQKLCcT1KaRPMj4lMLYQgECNwDIC8YhlhUe41PJmDAIQaCCwjEA84pLhES4N6WMCAQjUCCwgEO3Lk/aFs1RpadCGwDwEFhCI68dm9mjHxx4QPwhsncAcX72XMf4t22cXAhBYK4EFziDWioLXBQEI5AQQiJwI+xCAQCKAQCQUNCAAgZzAhb6ePO9jHwIQgMA9gTkWKQ+Hw/loS+CIX8ELnwocDQXgU0/QOcolhhMg7hCITACBiFxdcoOAkwAC4QSIOwQiE0AgIleX3CDgJIBAOAHiDoHIBBCIyNUlNwg4CSAQToC4QyAyAQQicnXJDQJOAgiEEyDuEIhMAIGIXF1yg4CTAALhBIg7BCITQCAiV5fcIOAkgEA4AeIOgcgEEIjI1SU3CDgJIBBOgLhDIDIBBCJydckNAk4CCIQTIO4QiEwAgYhcXXKDgJMAAuEEiDsEIhNAICJXl9wg4CSAQDgB4g6ByAQQiMjVJTcIOAkgEE6AuEMgMgEEInJ1yQ0CTgIIhBMg7hCITACBiFxdcoOAkwAC4QSIOwQiE0AgIleX3CDgJIBAOAHiDoHIBBCIyNUlNwg4CVzo68+dIXCHAASiEng+Q2KHw+F8R5HAEb+CFz4VOBoKwKeeoHOUSwwnQNwhEJkAAhG5uuQGAScBBMIJEHcIRCaAQESuLrlBwEkAgXACxB0CkQkgEJGrS24QcBJAIJwAcYdAZAIIROTqkhsEnAQQCCdA3CEQmQACEbm65AYBJwEEwgkQdwhEJoBARK4uuUHASQCBcALEHQKRCSAQkatLbhBwEkAgnABxh0BkAghE5OqSGwScBBAIJ0DcIRCZAAIRubrkBgEnAQTCCRB3CEQmgEBEri65QcBJAIFwAsQdApEJIBCRq0tuEHASQCCcAHGHQGQCCETk6pIbBJwEEAgnQNwhEJkAAhG5uuQGAScBBMIJEHcIRCaAQESuLrlBwEkAgXACxB0CkQkgEJGrS24QcBK40NefO0PgDgEIRCXAGUTUypIXBCYggEBMAJEQEIhKAIGIWlnygsAEBBCICSASAgJRCSAQUStLXhCYgAACMQFEQkAgKgEEImplyQsCExBAICaASAgIRCWAQEStLHlBYAICCMQEEAkBgagEEIiolSUvCExAAIGYACIhIBCVAAIRtbLkBYEJCCAQE0AkBASiEkAgolaWvCAwAQEEYgKIhIBAVAIIRN
TKkhcEJiCAQEwAkRAQiEoAgYhaWfKCwAQEEIgJIBICAlEJIBBRK0teEJiAAAIxAURCQCAqAQQiamXJCwITEEAgJoBICAhEJYBARK0seUFgAgIIxAQQCQGBqAQQiKiVJS8ITEAAgZgAIiEgEJUAAhG1suQFgQkIIBATQCQEBKISQCCiVpa8IDABgf8DTJQdHhg8t4kAAAAASUVORK5CYII=",
|
|
"text/plain": [
|
|
"<PIL.Image.Image image mode=RGB size=352x288>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"\n",
|
|
"Computing new shield\n",
|
|
"LOG: Starting with explicit model creation...\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Path to the grid-to-PRISM converter binary (presumably Minigrid2PRISM), taken from\n",
"# the M2P_BINARY environment variable; None when that variable is not set.\n",
"GRID_TO_PRISM_BINARY = os.environ.get(\"M2P_BINARY\")\n",
|
|
"\n",
|
|
"def mask_fn(env: gym.Env):\n",
"    \"\"\"ActionMasker callback: delegate to the env's own `create_action_mask()`.\"\"\"\n",
"    action_mask = env.create_action_mask()\n",
"    return action_mask\n",
|
|
"\n",
|
|
"def nomask_fn(env: gym.Env):\n",
"    \"\"\"ActionMasker callback for unshielded training: allow every action.\n",
"\n",
"    The mask length comes from the env's action space rather than a hard-coded 7,\n",
"    so it stays correct for environments with a different number of actions.\n",
"    Assumes a Discrete action space (as MiniGrid uses).\n",
"    \"\"\"\n",
"    return [1.0] * int(env.action_space.n)\n",
|
|
"\n",
|
|
"def main():\n",
"    \"\"\"Train a MaskablePPO agent on a MiniGrid environment, optionally with a shield.\n",
"\n",
"    Builds the environment, shows its initial render, computes one shield per value in\n",
"    `shield_values` (drawing an overlay of each), wraps the environment with the action\n",
"    mask selected by `shielding`, and trains for `steps` timesteps.\n",
"\n",
"    Raises:\n",
"        ValueError: if `shielding` is neither Training nor Disabled, or if\n",
"            `value_for_training` is not one of `shield_values`.\n",
"        RuntimeError: if a shield is needed but M2P_BINARY is not set.\n",
"    \"\"\"\n",
"    # Alternative environment: \"MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0\"\n",
"    env_id = \"MiniGrid-WindyCity-Adv-v0\"\n",
"\n",
"    formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
"    shield_values = [0.9, 0.95, 0.99, 1.0]  # one shield is computed and drawn per value\n",
"    value_for_training = 0.99  # must be one of shield_values\n",
"    shield_comparison = \"absolute\"\n",
"    shielding = ShieldingConfig.Training\n",
"    steps = 200  # total timesteps passed to model.learn\n",
"\n",
"    # Validate the configuration before doing any expensive shield computation.\n",
"    if shielding not in (ShieldingConfig.Training, ShieldingConfig.Disabled):\n",
"        raise ValueError(f\"Unsupported shielding configuration: {shielding}\")\n",
"    if shielding == ShieldingConfig.Training and value_for_training not in shield_values:\n",
"        raise ValueError(f\"value_for_training={value_for_training} is not one of {shield_values}\")\n",
"\n",
"    logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
"\n",
"    env = gym.make(env_id, render_mode=\"rgb_array\")\n",
"    # Separate view of the same base env at TILE_PIXELS per tile; only used for overlays.\n",
"    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
"    env = RGBImgObsWrapper(env, 8)\n",
"    env = ImgObsWrapper(env)\n",
"    env = MiniWrapper(env)\n",
"\n",
"    env.reset()\n",
"    Image.fromarray(env.render()).show()\n",
"\n",
"    shield_handlers = dict()\n",
"    if shield_needed(shielding):\n",
"        if GRID_TO_PRISM_BINARY is None:\n",
"            raise RuntimeError(\"Set the M2P_BINARY environment variable to compute shields\")\n",
"        for value in shield_values:\n",
"            shield_handler = MiniGridShieldHandler(\n",
"                GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula,\n",
"                shield_value=value, shield_comparison=shield_comparison,\n",
"                nocleanup=True, prism_file=None,\n",
"            )\n",
"            # NOTE(review): `env` is re-wrapped once per shield value, so the training env\n",
"            # ends up nested in several shielding wrappers. Kept because constructing the\n",
"            # wrapper may be what computes the shield -- confirm in sb3utils before removing.\n",
"            env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
"            shield_handlers[value] = shield_handler\n",
"\n",
"        for value in shield_values:\n",
"            # Print the caption first so it precedes the overlay it describes.\n",
"            print(f\"The shield for shield_value = {value}\")\n",
"            create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n",
"\n",
"    if shielding == ShieldingConfig.Training:\n",
"        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
"        env = ActionMasker(env, mask_fn)\n",
"        print(\"Training with shield:\")\n",
"        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
"    else:  # ShieldingConfig.Disabled, validated above\n",
"        env = ActionMasker(env, nomask_fn)\n",
"\n",
"    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
"    model.set_logger(logger)\n",
"    model.learn(steps, callback=[InfoCallback()])\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"if __name__ == \"__main__\":\n",
"    # Inside a Jupyter kernel __name__ is \"__main__\", so executing this cell starts training.\n",
"    print(\"Starting the training\")\n",
"    main()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|