4 changed files with 164 additions and 1564 deletions
			
			
		- 
					574notebooks/FaultyActions.ipynb
- 
					228notebooks/GSW_Playground.ipynb
- 
					147notebooks/HelloLavaGap.ipynb
- 
					779notebooks/SlipperyCliff.ipynb
						
							
						
						
							574
	
						
						notebooks/FaultyActions.ipynb
						
							File diff suppressed because it is too large
							
							
								
									View File
								
							
						
					
				File diff suppressed because it is too large
							
							
								
									View File
								
							
						
						
							
						
						
							228
	
						
						notebooks/GSW_Playground.ipynb
						
							File diff suppressed because it is too large
							
							
								
									View File
								
							
						
					
				File diff suppressed because it is too large
							
							
								
									View File
								
							
						| @ -0,0 +1,147 @@ | |||
| { | |||
|  "cells": [ | |||
|   { | |||
|    "cell_type": "markdown", | |||
|    "metadata": {}, | |||
|    "source": [ | |||
|     "## Example usage of Tempestpy" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": null, | |||
|    "metadata": { | |||
|     "vscode": { | |||
|      "languageId": "plaintext" | |||
|     } | |||
|    }, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "from sb3_contrib import MaskablePPO\n", | |||
|     "from sb3_contrib.common.wrappers import ActionMasker\n", | |||
|     "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", | |||
|     "\n", | |||
|     "import gymnasium as gym\n", | |||
|     "\n", | |||
|     "from minigrid.core.actions import Actions\n", | |||
|     "from minigrid.core.constants import TILE_PIXELS\n", | |||
|     "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", | |||
|     "\n", | |||
|     "import tempfile, datetime, shutil\n", | |||
|     "\n", | |||
|     "import time\n", | |||
|     "import os\n", | |||
|     "\n", | |||
|     "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", | |||
|     "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", | |||
|     "\n", | |||
|     "import os, sys\n", | |||
|     "from copy import deepcopy\n", | |||
|     "\n", | |||
|     "from PIL import Image" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": null, | |||
|    "metadata": { | |||
|     "vscode": { | |||
|      "languageId": "plaintext" | |||
|     } | |||
|    }, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", | |||
|     "\n", | |||
|     "def mask_fn(env: gym.Env):\n", | |||
|     "    return env.create_action_mask()\n", | |||
|     "\n", | |||
|     "def nomask_fn(env: gym.Env):\n", | |||
|     "    return [1.0] * 7\n", | |||
|     "\n", | |||
|     "def main():\n", | |||
|     "    env = \"MiniGrid-LavaGapS6-v0\"\n", | |||
|     "\n", | |||
|     "    # TODO Change the safety specification\n", | |||
|     "    formula = \"Pmax=? [G !AgentIsOnLava]\"\n", | |||
|     "    value_for_training = 1.0\n", | |||
|     "    shield_comparison =  \"absolute\"\n", | |||
|     "    shielding = ShieldingConfig.Training\n", | |||
|     "    \n", | |||
|     "    logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", | |||
|     "    \n", | |||
|     "\n", | |||
|     "    env = gym.make(env, render_mode=\"rgb_array\")\n", | |||
|     "    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", | |||
|     "    env = RGBImgObsWrapper(env, 8)\n", | |||
|     "    env = ImgObsWrapper(env)\n", | |||
|     "    env = MiniWrapper(env)\n", | |||
|     "\n", | |||
|     "    \n", | |||
|     "    env.reset()\n", | |||
|     "    Image.fromarray(env.render()).show()\n", | |||
|     "        \n", | |||
|     "    shield_handlers = dict()\n", | |||
|     "    if shield_needed(shielding):\n", | |||
|     "        for value in [0.0, 1.0]:\n", | |||
|     "            shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n", | |||
|     "            env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", | |||
|     "            create_shield_overlay_image(image_env, shield_handler.create_shield())\n", | |||
|     "            print(f\"The shield for shield_value = {value}\")\n", | |||
|     "\n", | |||
|     "            shield_handlers[value] = shield_handler\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "    if shielding == ShieldingConfig.Training:\n", | |||
|     "        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", | |||
|     "        env = ActionMasker(env, mask_fn)\n", | |||
|     "        print(\"Training with shield:\")\n", | |||
|     "        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", | |||
|     "    elif shielding == ShieldingConfig.Disabled:\n", | |||
|     "        env = ActionMasker(env, nomask_fn)\n", | |||
|     "    else:\n", | |||
|     "        assert(False) \n", | |||
|     "    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", | |||
|     "    model.set_logger(logger)\n", | |||
|     "    steps = 20_000\n", | |||
|     "\n", | |||
|     " \n", | |||
|     "    model.learn(steps,callback=[InfoCallback()])\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "if __name__ == '__main__':\n", | |||
|     "    print(\"Starting the training\")\n", | |||
|     "    main()" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": null, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [] | |||
|   } | |||
|  ], | |||
|  "metadata": { | |||
|   "kernelspec": { | |||
|    "display_name": "Python 3 (ipykernel)", | |||
|    "language": "python", | |||
|    "name": "python3" | |||
|   }, | |||
|   "language_info": { | |||
|    "codemirror_mode": { | |||
|     "name": "ipython", | |||
|     "version": 3 | |||
|    }, | |||
|    "file_extension": ".py", | |||
|    "mimetype": "text/x-python", | |||
|    "name": "python", | |||
|    "nbconvert_exporter": "python", | |||
|    "pygments_lexer": "ipython3", | |||
|    "version": "3.10.12" | |||
|   } | |||
|  }, | |||
|  "nbformat": 4, | |||
|  "nbformat_minor": 4 | |||
| } | |||
						
							
						
						
							779
	
						
						notebooks/SlipperyCliff.ipynb
						
							File diff suppressed because it is too large
							
							
								
									View File
								
							
						
					
				File diff suppressed because it is too large
							
							
								
									View File
								
							
						
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue