sp
2 months ago
4 changed files with 164 additions and 1564 deletions
-
574notebooks/FaultyActions.ipynb
-
228notebooks/GSW_Playground.ipynb
-
147notebooks/HelloLavaGap.ipynb
-
779notebooks/SlipperyCliff.ipynb
574
notebooks/FaultyActions.ipynb
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
228
notebooks/GSW_Playground.ipynb
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
@ -0,0 +1,147 @@ |
|||
{ |
|||
"cells": [ |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": {}, |
|||
"source": [ |
|||
"## Example usage of Tempestpy" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"vscode": { |
|||
"languageId": "plaintext" |
|||
} |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"from sb3_contrib import MaskablePPO\n", |
|||
"from sb3_contrib.common.wrappers import ActionMasker\n", |
|||
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", |
|||
"\n", |
|||
"import gymnasium as gym\n", |
|||
"\n", |
|||
"from minigrid.core.actions import Actions\n", |
|||
"from minigrid.core.constants import TILE_PIXELS\n", |
|||
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", |
|||
"\n", |
|||
"import tempfile, datetime, shutil\n", |
|||
"\n", |
|||
"import time\n", |
|||
"import os\n", |
|||
"\n", |
|||
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", |
|||
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", |
|||
"\n", |
|||
"import os, sys\n", |
|||
"from copy import deepcopy\n", |
|||
"\n", |
|||
"from PIL import Image" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"vscode": { |
|||
"languageId": "plaintext" |
|||
} |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", |
|||
"\n", |
|||
"def mask_fn(env: gym.Env):\n", |
|||
" return env.create_action_mask()\n", |
|||
"\n", |
|||
"def nomask_fn(env: gym.Env):\n", |
|||
" return [1.0] * 7\n", |
|||
"\n", |
|||
"def main():\n", |
|||
" env = \"MiniGrid-LavaGapS6-v0\"\n", |
|||
"\n", |
|||
" # TODO Change the safety specification\n", |
|||
" formula = \"Pmax=? [G !AgentIsOnLava]\"\n", |
|||
" value_for_training = 1.0\n", |
|||
" shield_comparison = \"absolute\"\n", |
|||
" shielding = ShieldingConfig.Training\n", |
|||
" \n", |
|||
" logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", |
|||
" \n", |
|||
"\n", |
|||
" env = gym.make(env, render_mode=\"rgb_array\")\n", |
|||
" image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", |
|||
" env = RGBImgObsWrapper(env, 8)\n", |
|||
" env = ImgObsWrapper(env)\n", |
|||
" env = MiniWrapper(env)\n", |
|||
"\n", |
|||
" \n", |
|||
" env.reset()\n", |
|||
" Image.fromarray(env.render()).show()\n", |
|||
" \n", |
|||
" shield_handlers = dict()\n", |
|||
" if shield_needed(shielding):\n", |
|||
" for value in [0.0, 1.0]:\n", |
|||
" shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n", |
|||
" env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", |
|||
" create_shield_overlay_image(image_env, shield_handler.create_shield())\n", |
|||
" print(f\"The shield for shield_value = {value}\")\n", |
|||
"\n", |
|||
" shield_handlers[value] = shield_handler\n", |
|||
"\n", |
|||
"\n", |
|||
" if shielding == ShieldingConfig.Training:\n", |
|||
" env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", |
|||
" env = ActionMasker(env, mask_fn)\n", |
|||
" print(\"Training with shield:\")\n", |
|||
" create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", |
|||
" elif shielding == ShieldingConfig.Disabled:\n", |
|||
" env = ActionMasker(env, nomask_fn)\n", |
|||
" else:\n", |
|||
" assert(False) \n", |
|||
" model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", |
|||
" model.set_logger(logger)\n", |
|||
" steps = 20_000\n", |
|||
"\n", |
|||
" \n", |
|||
" model.learn(steps,callback=[InfoCallback()])\n", |
|||
"\n", |
|||
"\n", |
|||
"\n", |
|||
"if __name__ == '__main__':\n", |
|||
" print(\"Starting the training\")\n", |
|||
" main()" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": {}, |
|||
"outputs": [], |
|||
"source": [] |
|||
} |
|||
], |
|||
"metadata": { |
|||
"kernelspec": { |
|||
"display_name": "Python 3 (ipykernel)", |
|||
"language": "python", |
|||
"name": "python3" |
|||
}, |
|||
"language_info": { |
|||
"codemirror_mode": { |
|||
"name": "ipython", |
|||
"version": 3 |
|||
}, |
|||
"file_extension": ".py", |
|||
"mimetype": "text/x-python", |
|||
"name": "python", |
|||
"nbconvert_exporter": "python", |
|||
"pygments_lexer": "ipython3", |
|||
"version": "3.10.12" |
|||
} |
|||
}, |
|||
"nbformat": 4, |
|||
"nbformat_minor": 4 |
|||
} |
779
notebooks/SlipperyCliff.ipynb
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
Reference in new issue