You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
143 lines
4.4 KiB
143 lines
4.4 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Example usage of Tempestpy"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
"languageId": "python"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sb3_contrib import MaskablePPO\n",
|
|
"from sb3_contrib.common.wrappers import ActionMasker\n",
|
|
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
|
|
"\n",
|
|
"import gymnasium as gym\n",
|
|
"\n",
|
|
"from minigrid.core.actions import Actions\n",
|
|
"from minigrid.core.constants import TILE_PIXELS\n",
|
|
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
|
|
"\n",
|
|
"import tempfile, datetime, shutil\n",
|
|
"\n",
|
|
"import time\n",
|
|
"import os\n",
|
|
"\n",
|
|
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
|
|
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
|
|
"\n",
|
|
"import os, sys\n",
|
|
"from copy import deepcopy\n",
|
|
"\n",
|
|
"from PIL import Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"vscode": {
"languageId": "python"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Value of the M2P_BINARY environment variable (None when it is unset).\n",
"# It is handed to MiniGridShieldHandler as its first argument; presumably the\n",
"# path to the grid-to-PRISM converter binary -- TODO confirm.\n",
"GRID_TO_PRISM_BINARY = os.environ.get(\"M2P_BINARY\")\n",
|
|
"\n",
"def mask_fn(env: gym.Env):\n",
"    \"\"\"Return the action mask produced by the environment itself.\n",
"\n",
"    Delegates to ``env.create_action_mask()``; used as the mask callback\n",
"    given to ``ActionMasker`` when training with a shield.\n",
"    \"\"\"\n",
"    action_mask = env.create_action_mask()\n",
"    return action_mask\n",
|
|
"\n",
"def nomask_fn(env: gym.Env):\n",
"    \"\"\"Return a constant mask of seven ``1.0`` entries; ``env`` is ignored.\n",
"\n",
"    Used as the mask callback given to ``ActionMasker`` when shielding is\n",
"    disabled. NOTE(review): 7 presumably equals the number of MiniGrid\n",
"    actions -- confirm against the environment's action space.\n",
"    \"\"\"\n",
"    return [1.0 for _ in range(7)]\n",
|
|
"\n",
"def main():\n",
"    \"\"\"Build a MiniGrid lava environment and train a MaskablePPO agent on it.\n",
"\n",
"    Steps:\n",
"      1. Create the environment and its observation wrappers.\n",
"      2. If shielding is needed, create one shield handler per shield value\n",
"         in (0.0, 1.0) and render a shield overlay image for each.\n",
"      3. Wrap the environment in an ``ActionMasker`` that uses either the\n",
"         shield-based mask (Training) or the constant mask (Disabled).\n",
"      4. Train ``MaskablePPO`` for a fixed number of steps.\n",
"\n",
"    Raises:\n",
"        ValueError: if the shielding configuration is neither\n",
"            ``ShieldingConfig.Training`` nor ``ShieldingConfig.Disabled``.\n",
"    \"\"\"\n",
"    # Keep the environment id separate from the environment object.\n",
"    env_id = \"MiniGrid-LavaFaultyS15-1-v0\"\n",
"\n",
"    formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
"    value_for_training = 0.0\n",
"    shield_comparison = \"absolute\"\n",
"    shielding = ShieldingConfig.Training\n",
"\n",
"    logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n",
"\n",
"    env = gym.make(env_id, render_mode=\"rgb_array\")\n",
"    # Second view of the same base environment, rendered with TILE_PIXELS;\n",
"    # it is only passed to create_shield_overlay_image below.\n",
"    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
"    env = RGBImgObsWrapper(env, 8)\n",
"    env = ImgObsWrapper(env)\n",
"    env = MiniWrapper(env)\n",
"\n",
"    env.reset()\n",
"    Image.fromarray(env.render()).show()\n",
"\n",
"    shield_handlers = dict()\n",
"    if shield_needed(shielding):\n",
"        for value in [0.0, 1.0]:\n",
"            shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=False, prism_file=None)\n",
"            # NOTE(review): env is re-wrapped on every iteration and once more in\n",
"            # the Training branch below, so the shielding wrappers stack on top\n",
"            # of each other -- confirm this is intended.\n",
"            env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
"            create_shield_overlay_image(image_env, shield_handler.create_shield())\n",
"            shield_handlers[value] = shield_handler\n",
"\n",
"    if shielding == ShieldingConfig.Training:\n",
"        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
"        env = ActionMasker(env, mask_fn)\n",
"        print(\"Training with shield:\")\n",
"        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
"    elif shielding == ShieldingConfig.Disabled:\n",
"        env = ActionMasker(env, nomask_fn)\n",
"    else:\n",
"        # Explicit error instead of assert(False): asserts are stripped under\n",
"        # `python -O`, which would let an unmasked env reach the model.\n",
"        raise ValueError(f\"Unsupported shielding configuration: {shielding}\")\n",
"\n",
"    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
"    model.set_logger(logger)\n",
"    steps = 20_000\n",
"\n",
"    # The stray assert(False) that used to sit here made training unreachable.\n",
"    model.learn(steps, callback=[InfoCallback()])\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
"if __name__ == '__main__':\n",
"    # Announce the run, then hand control to main().\n",
"    print(\"Starting the training\")\n",
"    main()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|