{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example usage of Tempestpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "from sb3_contrib import MaskablePPO\n",
    "from sb3_contrib.common.wrappers import ActionMasker\n",
    "from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat\n",
    "\n",
    "import gymnasium as gym\n",
    "\n",
    "from minigrid.core.actions import Actions\n",
    "from minigrid.core.constants import TILE_PIXELS\n",
    "from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper\n",
    "\n",
    "import tempfile, datetime, shutil\n",
    "\n",
    "import time\n",
    "import os\n",
    "\n",
    "from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation, create_shield_overlay_image\n",
    "from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback\n",
    "\n",
    "import os, sys\n",
    "from copy import deepcopy\n",
    "\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "GRID_TO_PRISM_BINARY=os.getenv(\"M2P_BINARY\")\n",
    "\n",
    "def mask_fn(env: gym.Env):\n",
    "    \"\"\"Return the action mask of ``env`` by delegating to ``env.create_action_mask()``.\"\"\"\n",
    "    action_mask = env.create_action_mask()\n",
    "    return action_mask\n",
    "\n",
    "def nomask_fn(env: gym.Env):\n",
    "    \"\"\"Return an all-ones mask with 7 entries; ``env`` is not consulted.\"\"\"\n",
    "    return [1.0 for _ in range(7)]\n",
    "\n",
    "def main():\n",
    "    \"\"\"Train a MaskablePPO agent on a MiniGrid environment, optionally shielded.\n",
    "\n",
    "    Builds the environment, shows its initial rendering, creates one shield\n",
    "    handler per threshold in ``shield_values`` (when shielding is needed),\n",
    "    draws an overlay image for each shield and then runs a short training run.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``shielding`` is neither ``ShieldingConfig.Training``\n",
    "            nor ``ShieldingConfig.Disabled``.\n",
    "    \"\"\"\n",
    "    #env_name = \"MiniGrid-LavaSlipperyCliff-16x13-Slip10-Time-v0\"\n",
    "    env_name = \"MiniGrid-WindyCity2-v0\"\n",
    "\n",
    "    formula = \"Pmax=? [G ! AgentIsOnLava]\"\n",
    "    # Thresholds for which a shield handler is built and visualised.\n",
    "    shield_values = [0.9, 0.95, 0.99, 1.0]\n",
    "    # Must be one of shield_values: it is used as a key into shield_handlers.\n",
    "    value_for_training = 0.99\n",
    "    shield_comparison = \"absolute\"\n",
    "    shielding = ShieldingConfig.Training\n",
    "    steps = 200\n",
    "\n",
    "    # Human-readable logging to stdout only; the system temp dir is passed as\n",
    "    # the logger folder instead of a hard-coded \"/tmp\".\n",
    "    logger = Logger(tempfile.gettempdir(), output_formats=[HumanOutputFormat(sys.stdout)])\n",
    "\n",
    "    env = gym.make(env_name, render_mode=\"rgb_array\")\n",
    "    # Second view of the same base env with TILE_PIXELS-sized tiles; it is only\n",
    "    # passed to create_shield_overlay_image below.\n",
    "    image_env = RGBImgObsWrapper(env, TILE_PIXELS)\n",
    "    env = RGBImgObsWrapper(env, 8)\n",
    "    env = ImgObsWrapper(env)\n",
    "    env = MiniWrapper(env)\n",
    "\n",
    "    env.reset()\n",
    "    Image.fromarray(env.render()).show()\n",
    "\n",
    "    shield_handlers = dict()\n",
    "    if shield_needed(shielding):\n",
    "        for value in shield_values:\n",
    "            shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, \"grid.txt\", \"grid.prism\", formula, shield_value=value, shield_comparison=shield_comparison, nocleanup=True, prism_file=None)\n",
    "            # NOTE(review): each iteration stacks another shielding wrapper on\n",
    "            # env. Kept as in the original because constructing the wrapper may\n",
    "            # be what initialises the handler -- confirm against sb3utils.\n",
    "            env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)\n",
    "            shield_handlers[value] = shield_handler\n",
    "\n",
    "        # All handlers are created first; only then are the overlays drawn.\n",
    "        for value in shield_values:\n",
    "            create_shield_overlay_image(image_env, shield_handlers[value].create_shield())\n",
    "            print(f\"The shield for shield_value = {value}\")\n",
    "\n",
    "    if shielding == ShieldingConfig.Training:\n",
    "        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handlers[value_for_training], create_shield_at_reset=False)\n",
    "        env = ActionMasker(env, mask_fn)\n",
    "        print(\"Training with shield:\")\n",
    "        create_shield_overlay_image(image_env, shield_handlers[value_for_training].create_shield())\n",
    "    elif shielding == ShieldingConfig.Disabled:\n",
    "        env = ActionMasker(env, nomask_fn)\n",
    "    else:\n",
    "        # Explicit error instead of assert(False): asserts are stripped under\n",
    "        # `python -O`, which would let training continue without an ActionMasker.\n",
    "        raise ValueError(f\"Unsupported shielding configuration: {shielding}\")\n",
    "\n",
    "    model = MaskablePPO(\"CnnPolicy\", env, verbose=1, device=\"auto\")\n",
    "    model.set_logger(logger)\n",
    "\n",
    "    model.learn(steps, callback=[InfoCallback()])\n",
    "\n",
    "\n",
    "\n",
    "# Run the training when this cell/script is executed as the main module.\n",
    "if __name__ == \"__main__\":\n",
    "    print(\"Starting the training\")\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}