You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

147 lines
4.7 KiB

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example usage of Tempestpy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from sb3_contrib import MaskablePPO\n",
"from sb3_contrib.common.wrappers import ActionMasker\n",
"from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
"\n",
"import gymnasium as gym\n",
"\n",
"from minigrid.core.actions import Actions\n",
"from minigrid.core.constants import TILE_PIXELS\n",
"from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
"\n",
"import tempfile, datetime, shutil\n",
"\n",
"import time\n",
"import os\n",
"\n",
"from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
"from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
"\n",
"import os, sys\n",
"from copy import deepcopy\n",
"\n",
"from PIL import Image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"# Value of the M2P_BINARY environment variable (None when it is unset);\n",
"# main() passes it as the first argument to MiniGridShieldHandler.\n",
"GRID_TO_PRISM_BINARY = os.environ.get(\"M2P_BINARY\")\n",
"\n",
"def mask_fn(env: gym.Env):\n",
"    \"\"\"Return the action mask produced by the environment's own `create_action_mask()`.\"\"\"\n",
"    action_mask = env.create_action_mask()\n",
"    return action_mask\n",
"\n",
"def nomask_fn(env: gym.Env):\n",
"    \"\"\"Return a mask of seven 1.0 entries; `env` is accepted but never inspected.\"\"\"\n",
"    mask_length = 7\n",
"    return [1.0 for _ in range(mask_length)]\n",
"\n",
"def main():\n",
"    \"\"\"Train a MaskablePPO agent on a MiniGrid environment, optionally with a shield.\n",
"\n",
"    Steps: build the wrapped environment, create one MiniGridShieldHandler per\n",
"    candidate shield value (when shield_needed(shielding) is true), call\n",
"    create_shield_overlay_image for every shield, wrap the environment in an\n",
"    ActionMasker that matches the shielding config, and run a short\n",
"    MaskablePPO training loop.\n",
"\n",
"    Raises:\n",
"        ValueError: if the shielding config is neither Training nor Disabled.\n",
"    \"\"\"\n",
"    # Alternative environment:\n",
"    #env = \"MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0\"\n",
"    env = \"MiniGrid-WindyCity2-v0\"\n",
"\n",
"    formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
"    value_for_training = 0.99\n",
"    shield_comparison = \"absolute\"\n",
"    shielding = ShieldingConfig.Training\n",
"    # Candidate shield values; value_for_training is looked up among them below.\n",
"    shield_values = [0.9, 0.95, 0.99, 1.0]\n",
"\n",
"    # Only a HumanOutputFormat writing to sys.stdout is registered. The logger\n",
"    # folder is the platform temp directory instead of a hard-coded \"/tmp\".\n",
"    logger = Logger(tempfile.gettempdir(), output_formats=[HumanOutputFormat(sys.stdout)])\n",
"\n",
"    env = gym.make(env, render_mode=\"rgb_array\")\n",
"    # image_env is a separate RGB wrapper (built with TILE_PIXELS) that is only\n",
"    # handed to create_shield_overlay_image; the training env uses 8 instead.\n",
"    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
"    env = RGBImgObsWrapper(env, 8)\n",
"    env = ImgObsWrapper(env)\n",
"    env = MiniWrapper(env)\n",
"\n",
"    env.reset()\n",
"    Image.fromarray(env.render()).show()\n",
"\n",
"    shield_handlers = dict()\n",
"    if shield_needed(shielding):\n",
"        for value in shield_values:\n",
"            shield_handler = MiniGridShieldHandler(\n",
"                GRID_TO_PRISM_BINARY,\n",
"                \"grid.txt\",\n",
"                \"grid.prism\",\n",
"                formula,\n",
"                shield_value=value,\n",
"                shield_comparison=shield_comparison,\n",
"                nocleanup=True,\n",
"                prism_file=None,\n",
"            )\n",
"            # NOTE(review): env is re-wrapped on every iteration, so the shielding\n",
"            # wrappers stack on top of each other -- confirm this is intended.\n",
"            env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
"            shield_handlers[value] = shield_handler\n",
"\n",
"        # Second pass, after all handlers exist: one overlay per shield value.\n",
"        for value in shield_values:\n",
"            create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n",
"            print(f\"The shield for shield_value = {value}\")\n",
"\n",
"    if shielding == ShieldingConfig.Training:\n",
"        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
"        env = ActionMasker(env, mask_fn)\n",
"        print(\"Training with shield:\")\n",
"        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
"    elif shielding == ShieldingConfig.Disabled:\n",
"        env = ActionMasker(env, nomask_fn)\n",
"    else:\n",
"        # An explicit exception survives `python -O` (a bare assert does not)\n",
"        # and names the offending value.\n",
"        raise ValueError(f\"Unsupported shielding configuration: {shielding!r}\")\n",
"\n",
"    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
"    model.set_logger(logger)\n",
"    steps = 200\n",
"\n",
"    model.learn(steps, callback=[InfoCallback()])\n",
"\n",
"\n",
"\n",
"# Entry point: announce the run, then start training.\n",
"if __name__ == \"__main__\":\n",
"    print(\"Starting the training\")\n",
"    main()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}