{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Example usage of Tempestpy" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "from sb3_contrib import MaskablePPO\n", "from sb3_contrib.common.wrappers import ActionMasker\n", "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n", "\n", "import gymnasium as gym\n", "\n", "from minigrid.core.actions import Actions\n", "from minigrid.core.constants import TILE_PIXELS\n", "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n", "\n", "import tempfile, datetime, shutil\n", "\n", "import time\n", "import os\n", "\n", "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n", "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n", "\n", "import os, sys\n", "from copy import deepcopy\n", "\n", "from PIL import Image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n", "\n", "def mask_fn(env: gym.Env):\n", " return env.create_action_mask()\n", "\n", "def nomask_fn(env: gym.Env):\n", " return [1.0] * 7\n", "\n", "def main():\n", " env = \"MiniGrid-LavaFaultyS15-1-v0\"\n", " \n", " formula = \"Pmax=? [G ! AgentIsOnLava]\"\n", " value_for_training = 0.0\n", " shield_comparison = \"absolute\"\n", " shielding = ShieldingConfig.Training\n", " \n", " logger = Logger(\"/tmp\", output_formats=[HumanOutputFormat(sys.stdout)])\n", " \n", " env = gym.make(env, render_mode=\"rgb_array\")\n", " image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n", " env = RGBImgObsWrapper(env, 8)\n", " env = ImgObsWrapper(env)\n", " env = MiniWrapper(env)\n", "\n", " \n", " env.reset()\n", " Image.fromarray(env.render()).show()\n", " \n", " shield_handlers = dict()\n", " if shield_needed(shielding):\n", " for value in [0.0, 1.0]: \n", " shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=False, prism_file=None)\n", " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n", " create_shield_overlay_image(image_env, shield_handler.create_shield())\n", " shield_handlers[value] = shield_handler\n", "\n", "\n", " if shielding == ShieldingConfig.Training:\n", " env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n", " env = ActionMasker(env, mask_fn)\n", " print(\"Training with shield:\")\n", " create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n", " elif shielding == ShieldingConfig.Disabled:\n", " env = ActionMasker(env, nomask_fn)\n", " else:\n", " assert(False) \n", " model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n", " model.set_logger(logger)\n", " steps = 20_000\n", "\n", " assert(False)\n", " model.learn(steps,callback=[InfoCallback()])\n", "\n", "\n", "\n", "if __name__ == '__main__':\n", " print(\"Starting the training\")\n", " main()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 4 }