4 changed files with 164 additions and 1564 deletions
			
			
		- 
					574notebooks/FaultyActions.ipynb
 - 
					228notebooks/GSW_Playground.ipynb
 - 
					147notebooks/HelloLavaGap.ipynb
 - 
					779notebooks/SlipperyCliff.ipynb
 
						
							
						
						
							574
	
						
						notebooks/FaultyActions.ipynb
						
							File diff suppressed because it is too large
							
							
								
									View File
								
							
						
					
				File diff suppressed because it is too large
							
							
								
									View File
								
							
						
						
							
						
						
							228
	
						
						notebooks/GSW_Playground.ipynb
						
							File diff suppressed because it is too large
							
							
								
									View File
								
							
						
					
				File diff suppressed because it is too large
							
							
								
									View File
								
							
						@ -0,0 +1,147 @@ | 
			
		|||||
 | 
				{ | 
			
		||||
 | 
				 "cells": [ | 
			
		||||
 | 
				  { | 
			
		||||
 | 
				   "cell_type": "markdown", | 
			
		||||
 | 
				   "metadata": {}, | 
			
		||||
 | 
				   "source": [ | 
			
		||||
 | 
				    "## Example usage of Tempestpy" | 
			
		||||
 | 
				   ] | 
			
		||||
 | 
				  }, | 
			
		||||
 | 
				  { | 
			
		||||
 | 
				   "cell_type": "code", | 
			
		||||
 | 
				   "execution_count": null, | 
			
		||||
 | 
				   "metadata": { | 
			
		||||
 | 
				    "vscode": { | 
			
		||||
 | 
				     "languageId": "plaintext" | 
			
		||||
 | 
				    } | 
			
		||||
 | 
				   }, | 
			
		||||
 | 
				   "outputs": [], | 
			
		||||
 | 
				   "source": [ | 
			
		||||
 | 
				    "from sb3_contrib import MaskablePPO\n", | 
			
		||||
 | 
				    "from sb3_contrib.common.wrappers import ActionMasker\n", | 
			
		||||
 | 
				    "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "import gymnasium as gym\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "from minigrid.core.actions import Actions\n", | 
			
		||||
 | 
				    "from minigrid.core.constants import TILE_PIXELS\n", | 
			
		||||
 | 
				    "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "import tempfile, datetime, shutil\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "import time\n", | 
			
		||||
 | 
				    "import os\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", | 
			
		||||
 | 
				    "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "import os, sys\n", | 
			
		||||
 | 
				    "from copy import deepcopy\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "from PIL import Image" | 
			
		||||
 | 
				   ] | 
			
		||||
 | 
				  }, | 
			
		||||
 | 
				  { | 
			
		||||
 | 
				   "cell_type": "code", | 
			
		||||
 | 
				   "execution_count": null, | 
			
		||||
 | 
				   "metadata": { | 
			
		||||
 | 
				    "vscode": { | 
			
		||||
 | 
				     "languageId": "plaintext" | 
			
		||||
 | 
				    } | 
			
		||||
 | 
				   }, | 
			
		||||
 | 
				   "outputs": [], | 
			
		||||
 | 
				   "source": [ | 
			
		||||
 | 
				    "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "def mask_fn(env: gym.Env):\n", | 
			
		||||
 | 
				    "    return env.create_action_mask()\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "def nomask_fn(env: gym.Env):\n", | 
			
		||||
 | 
				    "    return [1.0] * 7\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "def main():\n", | 
			
		||||
 | 
				    "    env = \"MiniGrid-LavaGapS6-v0\"\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "    # TODO Change the safety specification\n", | 
			
		||||
 | 
				    "    formula = \"Pmax=? [G !AgentIsOnLava]\"\n", | 
			
		||||
 | 
				    "    value_for_training = 1.0\n", | 
			
		||||
 | 
				    "    shield_comparison =  \"absolute\"\n", | 
			
		||||
 | 
				    "    shielding = ShieldingConfig.Training\n", | 
			
		||||
 | 
				    "    \n", | 
			
		||||
 | 
				    "    logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", | 
			
		||||
 | 
				    "    \n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "    env = gym.make(env, render_mode=\"rgb_array\")\n", | 
			
		||||
 | 
				    "    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", | 
			
		||||
 | 
				    "    env = RGBImgObsWrapper(env, 8)\n", | 
			
		||||
 | 
				    "    env = ImgObsWrapper(env)\n", | 
			
		||||
 | 
				    "    env = MiniWrapper(env)\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "    \n", | 
			
		||||
 | 
				    "    env.reset()\n", | 
			
		||||
 | 
				    "    Image.fromarray(env.render()).show()\n", | 
			
		||||
 | 
				    "        \n", | 
			
		||||
 | 
				    "    shield_handlers = dict()\n", | 
			
		||||
 | 
				    "    if shield_needed(shielding):\n", | 
			
		||||
 | 
				    "        for value in [0.0, 1.0]:\n", | 
			
		||||
 | 
				    "            shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n", | 
			
		||||
 | 
				    "            env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", | 
			
		||||
 | 
				    "            create_shield_overlay_image(image_env, shield_handler.create_shield())\n", | 
			
		||||
 | 
				    "            print(f\"The shield for shield_value = {value}\")\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "            shield_handlers[value] = shield_handler\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "    if shielding == ShieldingConfig.Training:\n", | 
			
		||||
 | 
				    "        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", | 
			
		||||
 | 
				    "        env = ActionMasker(env, mask_fn)\n", | 
			
		||||
 | 
				    "        print(\"Training with shield:\")\n", | 
			
		||||
 | 
				    "        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", | 
			
		||||
 | 
				    "    elif shielding == ShieldingConfig.Disabled:\n", | 
			
		||||
 | 
				    "        env = ActionMasker(env, nomask_fn)\n", | 
			
		||||
 | 
				    "    else:\n", | 
			
		||||
 | 
				    "        assert(False) \n", | 
			
		||||
 | 
				    "    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", | 
			
		||||
 | 
				    "    model.set_logger(logger)\n", | 
			
		||||
 | 
				    "    steps = 20_000\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    " \n", | 
			
		||||
 | 
				    "    model.learn(steps,callback=[InfoCallback()])\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "\n", | 
			
		||||
 | 
				    "if __name__ == '__main__':\n", | 
			
		||||
 | 
				    "    print(\"Starting the training\")\n", | 
			
		||||
 | 
				    "    main()" | 
			
		||||
 | 
				   ] | 
			
		||||
 | 
				  }, | 
			
		||||
 | 
				  { | 
			
		||||
 | 
				   "cell_type": "code", | 
			
		||||
 | 
				   "execution_count": null, | 
			
		||||
 | 
				   "metadata": {}, | 
			
		||||
 | 
				   "outputs": [], | 
			
		||||
 | 
				   "source": [] | 
			
		||||
 | 
				  } | 
			
		||||
 | 
				 ], | 
			
		||||
 | 
				 "metadata": { | 
			
		||||
 | 
				  "kernelspec": { | 
			
		||||
 | 
				   "display_name": "Python 3 (ipykernel)", | 
			
		||||
 | 
				   "language": "python", | 
			
		||||
 | 
				   "name": "python3" | 
			
		||||
 | 
				  }, | 
			
		||||
 | 
				  "language_info": { | 
			
		||||
 | 
				   "codemirror_mode": { | 
			
		||||
 | 
				    "name": "ipython", | 
			
		||||
 | 
				    "version": 3 | 
			
		||||
 | 
				   }, | 
			
		||||
 | 
				   "file_extension": ".py", | 
			
		||||
 | 
				   "mimetype": "text/x-python", | 
			
		||||
 | 
				   "name": "python", | 
			
		||||
 | 
				   "nbconvert_exporter": "python", | 
			
		||||
 | 
				   "pygments_lexer": "ipython3", | 
			
		||||
 | 
				   "version": "3.10.12" | 
			
		||||
 | 
				  } | 
			
		||||
 | 
				 }, | 
			
		||||
 | 
				 "nbformat": 4, | 
			
		||||
 | 
				 "nbformat_minor": 4 | 
			
		||||
 | 
				} | 
			
		||||
						
							
						
						
							779
	
						
						notebooks/SlipperyCliff.ipynb
						
							File diff suppressed because it is too large
							
							
								
									View File
								
							
						
					
				File diff suppressed because it is too large
							
							
								
									View File
								
							
						
		Reference in new issue