sp
2 months ago
4 changed files with 164 additions and 1564 deletions
-
574notebooks/FaultyActions.ipynb
-
228notebooks/GSW_Playground.ipynb
-
147notebooks/HelloLavaGap.ipynb
-
779notebooks/SlipperyCliff.ipynb
574
notebooks/FaultyActions.ipynb
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
228
notebooks/GSW_Playground.ipynb
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
@ -0,0 +1,147 @@ |
|||||
|
{ |
||||
|
"cells": [ |
||||
|
{ |
||||
|
"cell_type": "markdown", |
||||
|
"metadata": {}, |
||||
|
"source": [ |
||||
|
"## Example usage of Tempestpy" |
||||
|
] |
||||
|
}, |
||||
|
{ |
||||
|
"cell_type": "code", |
||||
|
"execution_count": null, |
||||
|
"metadata": { |
||||
|
"vscode": { |
||||
|
"languageId": "plaintext" |
||||
|
} |
||||
|
}, |
||||
|
"outputs": [], |
||||
|
"source": [ |
||||
|
"from sb3_contrib import MaskablePPO\n", |
||||
|
"from sb3_contrib.common.wrappers import ActionMasker\n", |
||||
|
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", |
||||
|
"\n", |
||||
|
"import gymnasium as gym\n", |
||||
|
"\n", |
||||
|
"from minigrid.core.actions import Actions\n", |
||||
|
"from minigrid.core.constants import TILE_PIXELS\n", |
||||
|
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", |
||||
|
"\n", |
||||
|
"import tempfile, datetime, shutil\n", |
||||
|
"\n", |
||||
|
"import time\n", |
||||
|
"import os\n", |
||||
|
"\n", |
||||
|
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", |
||||
|
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", |
||||
|
"\n", |
||||
|
"import os, sys\n", |
||||
|
"from copy import deepcopy\n", |
||||
|
"\n", |
||||
|
"from PIL import Image" |
||||
|
] |
||||
|
}, |
||||
|
{ |
||||
|
"cell_type": "code", |
||||
|
"execution_count": null, |
||||
|
"metadata": { |
||||
|
"vscode": { |
||||
|
"languageId": "plaintext" |
||||
|
} |
||||
|
}, |
||||
|
"outputs": [], |
||||
|
"source": [ |
||||
|
"GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", |
||||
|
"\n", |
||||
|
"def mask_fn(env: gym.Env):\n", |
||||
|
" return env.create_action_mask()\n", |
||||
|
"\n", |
||||
|
"def nomask_fn(env: gym.Env):\n", |
||||
|
" return [1.0] * 7\n", |
||||
|
"\n", |
||||
|
"def main():\n", |
||||
|
" env = \"MiniGrid-LavaGapS6-v0\"\n", |
||||
|
"\n", |
||||
|
" # TODO Change the safety specification\n", |
||||
|
" formula = \"Pmax=? [G !AgentIsOnLava]\"\n", |
||||
|
" value_for_training = 1.0\n", |
||||
|
" shield_comparison = \"absolute\"\n", |
||||
|
" shielding = ShieldingConfig.Training\n", |
||||
|
" \n", |
||||
|
" logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", |
||||
|
" \n", |
||||
|
"\n", |
||||
|
" env = gym.make(env, render_mode=\"rgb_array\")\n", |
||||
|
" image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", |
||||
|
" env = RGBImgObsWrapper(env, 8)\n", |
||||
|
" env = ImgObsWrapper(env)\n", |
||||
|
" env = MiniWrapper(env)\n", |
||||
|
"\n", |
||||
|
" \n", |
||||
|
" env.reset()\n", |
||||
|
" Image.fromarray(env.render()).show()\n", |
||||
|
" \n", |
||||
|
" shield_handlers = dict()\n", |
||||
|
" if shield_needed(shielding):\n", |
||||
|
" for value in [0.0, 1.0]:\n", |
||||
|
" shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n", |
||||
|
" env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", |
||||
|
" create_shield_overlay_image(image_env, shield_handler.create_shield())\n", |
||||
|
" print(f\"The shield for shield_value = {value}\")\n", |
||||
|
"\n", |
||||
|
" shield_handlers[value] = shield_handler\n", |
||||
|
"\n", |
||||
|
"\n", |
||||
|
" if shielding == ShieldingConfig.Training:\n", |
||||
|
" env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", |
||||
|
" env = ActionMasker(env, mask_fn)\n", |
||||
|
" print(\"Training with shield:\")\n", |
||||
|
" create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", |
||||
|
" elif shielding == ShieldingConfig.Disabled:\n", |
||||
|
" env = ActionMasker(env, nomask_fn)\n", |
||||
|
" else:\n", |
||||
|
" assert(False) \n", |
||||
|
" model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", |
||||
|
" model.set_logger(logger)\n", |
||||
|
" steps = 20_000\n", |
||||
|
"\n", |
||||
|
" \n", |
||||
|
" model.learn(steps,callback=[InfoCallback()])\n", |
||||
|
"\n", |
||||
|
"\n", |
||||
|
"\n", |
||||
|
"if __name__ == '__main__':\n", |
||||
|
" print(\"Starting the training\")\n", |
||||
|
" main()" |
||||
|
] |
||||
|
}, |
||||
|
{ |
||||
|
"cell_type": "code", |
||||
|
"execution_count": null, |
||||
|
"metadata": {}, |
||||
|
"outputs": [], |
||||
|
"source": [] |
||||
|
} |
||||
|
], |
||||
|
"metadata": { |
||||
|
"kernelspec": { |
||||
|
"display_name": "Python 3 (ipykernel)", |
||||
|
"language": "python", |
||||
|
"name": "python3" |
||||
|
}, |
||||
|
"language_info": { |
||||
|
"codemirror_mode": { |
||||
|
"name": "ipython", |
||||
|
"version": 3 |
||||
|
}, |
||||
|
"file_extension": ".py", |
||||
|
"mimetype": "text/x-python", |
||||
|
"name": "python", |
||||
|
"nbconvert_exporter": "python", |
||||
|
"pygments_lexer": "ipython3", |
||||
|
"version": "3.10.12" |
||||
|
} |
||||
|
}, |
||||
|
"nbformat": 4, |
||||
|
"nbformat_minor": 4 |
||||
|
} |
779
notebooks/SlipperyCliff.ipynb
File diff suppressed because it is too large
View File
File diff suppressed because it is too large
View File
Write
Preview
Loading…
Cancel
Save
Reference in new issue