10 changed files with 587 additions and 50 deletions
			
			
		- 
					13examples/shields/rl/11_minigridrl.py
- 
					5examples/shields/rl/13_minigridsb.py
- 
					19examples/shields/rl/14_train_eval.py
- 
					28examples/shields/rl/15_train_eval_tune.py
- 
					116examples/shields/rl/dqn_rllib.ipynb
- 
					117examples/shields/rl/ppo_rllib.ipynb
- 
					337examples/shields/rl/ppo_sb.ipynb
- 
					0examples/shields/rl/shieldhandlers.py
- 
					0examples/shields/rl/torch_action_mask_model.py
- 
					2examples/shields/rl/wrappers.py
| @ -0,0 +1,116 @@ | |||
| { | |||
|  "cells": [ | |||
|   { | |||
|    "cell_type": "markdown", | |||
|    "metadata": {}, | |||
|    "source": [ | |||
|     "## Example how to combine shielding with rllibs dqn algorithm." | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": null, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "import gymnasium as gym\n", | |||
|     "\n", | |||
|     "import minigrid\n", | |||
|     "\n", | |||
|     "from ray.tune import register_env\n", | |||
|     "from ray.rllib.algorithms.dqn.dqn import DQNConfig\n", | |||
|     "from ray.tune.logger import pretty_print\n", | |||
|     "from ray.rllib.models import ModelCatalog\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "from torch_action_mask_model import TorchActionMaskModel\n", | |||
|     "from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper\n", | |||
|     "from shieldhandlers import MiniGridShieldHandler, create_shield_query\n", | |||
|     "  " | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": 2, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "def shielding_env_creater(config):\n", | |||
|     "    name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n", | |||
|     "    framestack = config.get(\"framestack\", 4)\n", | |||
|     "    \n", | |||
|     "    shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n", | |||
|     "    \n", | |||
|     "    env = gym.make(name)\n", | |||
|     "    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=True)\n", | |||
|     "    env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, \"vector_index\") else 0,\n", | |||
|     "                                 framestack=framestack)\n", | |||
|     "    \n", | |||
|     "    return env\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "def register_minigrid_shielding_env():\n", | |||
|     "    env_name = \"mini-grid-shielding\"\n", | |||
|     "    register_env(env_name, shielding_env_creater)\n", | |||
|     "    ModelCatalog.register_custom_model(\n", | |||
|     "        \"shielding_model\", \n", | |||
|     "        TorchActionMaskModel)" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": null, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "register_minigrid_shielding_env()\n", | |||
|     "\n", | |||
|     "    \n", | |||
|     "config = DQNConfig()\n", | |||
|     "config = config.resources(num_gpus=0)\n", | |||
|     "config = config.rollouts(num_rollout_workers=1)\n", | |||
|     "config = config.environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\" })\n", | |||
|     "config = config.framework(\"torch\")\n", | |||
|     "config = config.rl_module(_enable_rl_module_api = False)\n", | |||
|     "config = config.training(hiddens=[], dueling=False, model={    \n", | |||
|     "        \"custom_model\": \"shielding_model\"\n", | |||
|     "})\n", | |||
|     "    \n", | |||
|     "algo = (\n", | |||
|     "    config.build()\n", | |||
|     ")\n", | |||
|     "        \n", | |||
|     "for i in range(30):\n", | |||
|     "    result = algo.train()\n", | |||
|     "    print(pretty_print(result))\n", | |||
|     "\n", | |||
|     "    if i % 5 == 0:\n", | |||
|     "        print(\"Saving checkpoint\")\n", | |||
|     "        checkpoint_dir = algo.save()\n", | |||
|     "        print(f\"Checkpoint saved in directory {checkpoint_dir}\")\n" | |||
|    ] | |||
|   } | |||
|  ], | |||
|  "metadata": { | |||
|   "kernelspec": { | |||
|    "display_name": "env", | |||
|    "language": "python", | |||
|    "name": "python3" | |||
|   }, | |||
|   "language_info": { | |||
|    "codemirror_mode": { | |||
|     "name": "ipython", | |||
|     "version": 3 | |||
|    }, | |||
|    "file_extension": ".py", | |||
|    "mimetype": "text/x-python", | |||
|    "name": "python", | |||
|    "nbconvert_exporter": "python", | |||
|    "pygments_lexer": "ipython3", | |||
|    "version": "3.10.12" | |||
|   }, | |||
|   "orig_nbformat": 4 | |||
|  }, | |||
|  "nbformat": 4, | |||
|  "nbformat_minor": 2 | |||
| } | |||
| @ -0,0 +1,117 @@ | |||
| { | |||
|  "cells": [ | |||
|   { | |||
|    "cell_type": "markdown", | |||
|    "metadata": {}, | |||
|    "source": [ | |||
|     "## Example how to combine shielding with rllibs ppo algorithm." | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": 4, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "import gymnasium as gym\n", | |||
|     "\n", | |||
|     "import minigrid\n", | |||
|     "\n", | |||
|     "from ray.tune import register_env\n", | |||
|     "from ray.rllib.algorithms.ppo import PPOConfig\n", | |||
|     "from ray.tune.logger import pretty_print\n", | |||
|     "from ray.rllib.models import ModelCatalog\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "from torch_action_mask_model import TorchActionMaskModel\n", | |||
|     "from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper\n", | |||
|     "from shieldhandlers import MiniGridShieldHandler, create_shield_query\n", | |||
|     "  " | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": 5, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "def shielding_env_creater(config):\n", | |||
|     "    name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n", | |||
|     "    framestack = config.get(\"framestack\", 4)\n", | |||
|     "    \n", | |||
|     "    shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n", | |||
|     "    \n", | |||
|     "    env = gym.make(name)\n", | |||
|     "    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=True)\n", | |||
|     "    env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, \"vector_index\") else 0,\n", | |||
|     "                                 framestack=framestack)\n", | |||
|     "    \n", | |||
|     "    return env\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "def register_minigrid_shielding_env():\n", | |||
|     "    env_name = \"mini-grid-shielding\"\n", | |||
|     "    register_env(env_name, shielding_env_creater)\n", | |||
|     "    ModelCatalog.register_custom_model(\n", | |||
|     "        \"shielding_model\", \n", | |||
|     "        TorchActionMaskModel)" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": null, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "register_minigrid_shielding_env()\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "config = (PPOConfig()\n", | |||
|     "    .rollouts(num_rollout_workers=1)\n", | |||
|     "    .resources(num_gpus=0)\n", | |||
|     "    .environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n", | |||
|     "    .framework(\"torch\")\n", | |||
|     "    .rl_module(_enable_rl_module_api = False)\n", | |||
|     "    .training(_enable_learner_api=False ,model={\n", | |||
|     "        \"custom_model\": \"shielding_model\"\n", | |||
|     "    }))\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "algo = (\n", | |||
|     "    config.build()\n", | |||
|     ")\n", | |||
|     "        \n", | |||
|     "for i in range(30):\n", | |||
|     "    result = algo.train()\n", | |||
|     "    print(pretty_print(result))\n", | |||
|     "\n", | |||
|     "    if i % 5 == 0:\n", | |||
|     "        print(\"Saving checkpoint\")\n", | |||
|     "        checkpoint_dir = algo.save()\n", | |||
|     "        print(f\"Checkpoint saved in directory {checkpoint_dir}\")" | |||
|    ] | |||
|   } | |||
|  ], | |||
|  "metadata": { | |||
|   "kernelspec": { | |||
|    "display_name": "env", | |||
|    "language": "python", | |||
|    "name": "python3" | |||
|   }, | |||
|   "language_info": { | |||
|    "codemirror_mode": { | |||
|     "name": "ipython", | |||
|     "version": 3 | |||
|    }, | |||
|    "file_extension": ".py", | |||
|    "mimetype": "text/x-python", | |||
|    "name": "python", | |||
|    "nbconvert_exporter": "python", | |||
|    "pygments_lexer": "ipython3", | |||
|    "version": "3.10.12" | |||
|   }, | |||
|   "orig_nbformat": 4 | |||
|  }, | |||
|  "nbformat": 4, | |||
|  "nbformat_minor": 2 | |||
| } | |||
| @ -0,0 +1,337 @@ | |||
| { | |||
|  "cells": [ | |||
|   { | |||
|    "cell_type": "markdown", | |||
|    "metadata": {}, | |||
|    "source": [ | |||
|     "## Example how to combine shielding with stable baselines contrib maskable ppo algorithm." | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": 1, | |||
|    "metadata": {}, | |||
|    "outputs": [ | |||
|     { | |||
|      "name": "stdout", | |||
|      "output_type": "stream", | |||
|      "text": [ | |||
|       "pygame 2.5.1 (SDL 2.28.2, Python 3.10.12)\n", | |||
|       "Hello from the pygame community. https://www.pygame.org/contribute.html\n" | |||
|      ] | |||
|     }, | |||
|     { | |||
|      "name": "stderr", | |||
|      "output_type": "stream", | |||
|      "text": [ | |||
|       "2023-09-08 10:00:46.717621: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", | |||
|       "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", | |||
|       "2023-09-08 10:00:47.771352: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" | |||
|      ] | |||
|     }, | |||
|     { | |||
|      "ename": "ModuleNotFoundError", | |||
|      "evalue": "No module named 'examples'", | |||
|      "output_type": "error", | |||
|      "traceback": [ | |||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
|       "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)", | |||
|       "Cell \u001b[0;32mIn[1], line 9\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mstable_baselines3\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcommon\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcallbacks\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseCallback\n\u001b[1;32m      7\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mgymnasium\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mgym\u001b[39;00m\n\u001b[0;32m----> 9\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mexamples\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mshields\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mrl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mshieldhandlers\u001b[39;00m \u001b[39mimport\u001b[39;00m MiniGridShieldHandler, create_shield_query\n\u001b[1;32m     10\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mexamples\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mshields\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mrl\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mwrappers\u001b[39;00m \u001b[39mimport\u001b[39;00m MiniGridSbShieldingWrapper\n", | |||
|       "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'examples'" | |||
|      ] | |||
|     } | |||
|    ], | |||
|    "source": [ | |||
|     "from sb3_contrib import MaskablePPO\n", | |||
|     "from sb3_contrib.common.maskable.evaluation import evaluate_policy\n", | |||
|     "from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy\n", | |||
|     "from sb3_contrib.common.wrappers import ActionMasker\n", | |||
|     "from stable_baselines3.common.callbacks import BaseCallback\n", | |||
|     "\n", | |||
|     "import gymnasium as gym\n", | |||
|     "\n", | |||
|     "from shieldhandlers import MiniGridShieldHandler, create_shield_query\n", | |||
|     "from wrappers import MiniGridSbShieldingWrapper" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": 2, | |||
|    "metadata": {}, | |||
|    "outputs": [], | |||
|    "source": [ | |||
|     "def mask_fn(env: gym.Env):\n", | |||
|     "    return env.create_action_mask()" | |||
|    ] | |||
|   }, | |||
|   { | |||
|    "cell_type": "code", | |||
|    "execution_count": 3, | |||
|    "metadata": {}, | |||
|    "outputs": [ | |||
|     { | |||
|      "name": "stdout", | |||
|      "output_type": "stream", | |||
|      "text": [ | |||
|       "Using cpu device\n", | |||
|       "Wrapping the env with a `Monitor` wrapper\n", | |||
|       "Wrapping the env in a DummyVecEnv.\n", | |||
|       "Wrapping the env in a VecTransposeImage.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWGVR  VRVRVRVRVRWG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VRGGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG  VRVRVRVRVRVRWG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWGVRVRVR  VRVRVRWG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR        VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWGVRVRVR  VRVRVRWG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR        VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG          VR  WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG          VRGGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWGVRVRVRVRVRVR  WG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n", | |||
|       "Write to file Grid.shield.\n", | |||
|       "---------------------------------\n", | |||
|       "| rollout/           |          |\n", | |||
|       "|    ep_len_mean     | 283      |\n", | |||
|       "|    ep_rew_mean     | 0.157    |\n", | |||
|       "| time/              |          |\n", | |||
|       "|    fps             | 165      |\n", | |||
|       "|    iterations      | 1        |\n", | |||
|       "|    time_elapsed    | 12       |\n", | |||
|       "|    total_timesteps | 2048     |\n", | |||
|       "---------------------------------\n", | |||
|       "\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Reading   :\tWGXR            WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWGVRVRVRVRVRVR  WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG              WG\n", | |||
|       "Reading   :\tWG            GGWG\n", | |||
|       "Reading   :\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWG              WG\n", | |||
|       "Background:\tWGWGWGWGWGWGWGWGWG\n", | |||
|       "\n" | |||
|      ] | |||
|     }, | |||
|     { | |||
|      "ename": "KeyboardInterrupt", | |||
|      "evalue": "", | |||
|      "output_type": "error", | |||
|      "traceback": [ | |||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
|       "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)", | |||
|       "Cell \u001b[0;32mIn[3], line 9\u001b[0m\n\u001b[1;32m      5\u001b[0m env \u001b[39m=\u001b[39m ActionMasker(env, mask_fn)\n\u001b[1;32m      6\u001b[0m model \u001b[39m=\u001b[39m MaskablePPO(MaskableActorCriticPolicy, env, verbose\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m model\u001b[39m.\u001b[39;49mlearn(\u001b[39m10_000\u001b[39;49m)\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:526\u001b[0m, in \u001b[0;36mMaskablePPO.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, use_masking, progress_bar)\u001b[0m\n\u001b[1;32m    523\u001b[0m callback\u001b[39m.\u001b[39mon_training_start(\u001b[39mlocals\u001b[39m(), \u001b[39mglobals\u001b[39m())\n\u001b[1;32m    525\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps \u001b[39m<\u001b[39m total_timesteps:\n\u001b[0;32m--> 526\u001b[0m     continue_training \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcollect_rollouts(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv, callback, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrollout_buffer, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_steps, use_masking)\n\u001b[1;32m    528\u001b[0m     \u001b[39mif\u001b[39;00m continue_training \u001b[39mis\u001b[39;00m \u001b[39mFalse\u001b[39;00m:\n\u001b[1;32m    529\u001b[0m         \u001b[39mbreak\u001b[39;00m\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/sb3_contrib/ppo_mask/ppo_mask.py:306\u001b[0m, in \u001b[0;36mMaskablePPO.collect_rollouts\u001b[0;34m(self, env, callback, rollout_buffer, n_rollout_steps, use_masking)\u001b[0m\n\u001b[1;32m    303\u001b[0m     actions, values, log_probs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpolicy(obs_tensor, action_masks\u001b[39m=\u001b[39maction_masks)\n\u001b[1;32m    305\u001b[0m actions \u001b[39m=\u001b[39m actions\u001b[39m.\u001b[39mcpu()\u001b[39m.\u001b[39mnumpy()\n\u001b[0;32m--> 306\u001b[0m new_obs, rewards, dones, infos \u001b[39m=\u001b[39m env\u001b[39m.\u001b[39;49mstep(actions)\n\u001b[1;32m    308\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m env\u001b[39m.\u001b[39mnum_envs\n\u001b[1;32m    310\u001b[0m \u001b[39m# Give access to local variables\u001b[39;00m\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/vec_env/base_vec_env.py:197\u001b[0m, in \u001b[0;36mVecEnv.step\u001b[0;34m(self, actions)\u001b[0m\n\u001b[1;32m    190\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m    191\u001b[0m \u001b[39mStep the environments with the given action\u001b[39;00m\n\u001b[1;32m    192\u001b[0m \n\u001b[1;32m    193\u001b[0m \u001b[39m:param actions: the action\u001b[39;00m\n\u001b[1;32m    194\u001b[0m \u001b[39m:return: observation, reward, done, information\u001b[39;00m\n\u001b[1;32m    195\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m    196\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstep_async(actions)\n\u001b[0;32m--> 197\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstep_wait()\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/vec_env/vec_transpose.py:95\u001b[0m, in \u001b[0;36mVecTransposeImage.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     94\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstep_wait\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m VecEnvStepReturn:\n\u001b[0;32m---> 95\u001b[0m     observations, rewards, dones, infos \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mvenv\u001b[39m.\u001b[39;49mstep_wait()\n\u001b[1;32m     97\u001b[0m     \u001b[39m# Transpose the terminal observations\u001b[39;00m\n\u001b[1;32m     98\u001b[0m     \u001b[39mfor\u001b[39;00m idx, done \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(dones):\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:70\u001b[0m, in \u001b[0;36mDummyVecEnv.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     67\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_dones[env_idx]:\n\u001b[1;32m     68\u001b[0m         \u001b[39m# save final observation where user can get it, then reset\u001b[39;00m\n\u001b[1;32m     69\u001b[0m         \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_infos[env_idx][\u001b[39m\"\u001b[39m\u001b[39mterminal_observation\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m obs\n\u001b[0;32m---> 70\u001b[0m         obs, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreset_infos[env_idx] \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menvs[env_idx]\u001b[39m.\u001b[39;49mreset()\n\u001b[1;32m     71\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_save_obs(env_idx, obs)\n\u001b[1;32m     72\u001b[0m \u001b[39mreturn\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_obs_from_buf(), np\u001b[39m.\u001b[39mcopy(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_rews), np\u001b[39m.\u001b[39mcopy(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_dones), deepcopy(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuf_infos))\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/stable_baselines3/common/monitor.py:83\u001b[0m, in \u001b[0;36mMonitor.reset\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m     81\u001b[0m         \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mExpected you to pass keyword argument \u001b[39m\u001b[39m{\u001b[39;00mkey\u001b[39m}\u001b[39;00m\u001b[39m into reset\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m     82\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcurrent_reset_info[key] \u001b[39m=\u001b[39m value\n\u001b[0;32m---> 83\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv\u001b[39m.\u001b[39;49mreset(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/env/lib/python3.10/site-packages/gymnasium/core.py:414\u001b[0m, in \u001b[0;36mWrapper.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m    410\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mreset\u001b[39m(\n\u001b[1;32m    411\u001b[0m     \u001b[39mself\u001b[39m, \u001b[39m*\u001b[39m, seed: \u001b[39mint\u001b[39m \u001b[39m|\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m, options: \u001b[39mdict\u001b[39m[\u001b[39mstr\u001b[39m, Any] \u001b[39m|\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m    412\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mtuple\u001b[39m[WrapperObsType, \u001b[39mdict\u001b[39m[\u001b[39mstr\u001b[39m, Any]]:\n\u001b[1;32m    413\u001b[0m \u001b[39m    \u001b[39m\u001b[39m\"\"\"Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 414\u001b[0m     \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv\u001b[39m.\u001b[39;49mreset(seed\u001b[39m=\u001b[39;49mseed, options\u001b[39m=\u001b[39;49moptions)\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/examples/shields/rl/Wrappers.py:222\u001b[0m, in \u001b[0;36mMiniGridSbShieldingWrapper.reset\u001b[0;34m(self, seed, options)\u001b[0m\n\u001b[1;32m    219\u001b[0m obs, infos \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39menv\u001b[39m.\u001b[39mreset(seed\u001b[39m=\u001b[39mseed, options\u001b[39m=\u001b[39moptions)\n\u001b[1;32m    221\u001b[0m keys \u001b[39m=\u001b[39m extract_keys(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39menv)\n\u001b[0;32m--> 222\u001b[0m shield \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshield_creator\u001b[39m.\u001b[39;49mcreate_shield(env\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49menv)\n\u001b[1;32m    224\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mkeys \u001b[39m=\u001b[39m keys\n\u001b[1;32m    225\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mshield \u001b[39m=\u001b[39m shield\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/examples/shields/rl/ShieldHandlers.py:82\u001b[0m, in \u001b[0;36mMiniGridShieldHandler.create_shield\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m     79\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__export_grid_to_text(env)\n\u001b[1;32m     80\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__create_prism()\n\u001b[0;32m---> 82\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__create_shield_dict()\n", | |||
|       "File \u001b[0;32m~/Documents/Projects/tempestpy/examples/shields/rl/ShieldHandlers.py:66\u001b[0m, in \u001b[0;36mMiniGridShieldHandler.__create_shield_dict\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     64\u001b[0m choice \u001b[39m=\u001b[39m shield_scheduler\u001b[39m.\u001b[39mget_choice(stateID)\n\u001b[1;32m     65\u001b[0m choices \u001b[39m=\u001b[39m choice\u001b[39m.\u001b[39mchoice_map\n\u001b[0;32m---> 66\u001b[0m state_valuation \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39;49mstate_valuations\u001b[39m.\u001b[39;49mget_string(stateID)\n\u001b[1;32m     68\u001b[0m actions_to_be_executed \u001b[39m=\u001b[39m [(choice[\u001b[39m1\u001b[39m] ,model\u001b[39m.\u001b[39mchoice_labeling\u001b[39m.\u001b[39mget_labels_of_choice(model\u001b[39m.\u001b[39mget_choice_index(stateID, choice[\u001b[39m1\u001b[39m]))) \u001b[39mfor\u001b[39;00m choice \u001b[39min\u001b[39;00m choices]\n\u001b[1;32m     70\u001b[0m action_dictionary[state_valuation] \u001b[39m=\u001b[39m actions_to_be_executed\n", | |||
|       "\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |||
|      ] | |||
|     } | |||
|    ], | |||
|    "source": [ | |||
|     "shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n", | |||
|     "\n", | |||
|     "env = gym.make(\"MiniGrid-LavaCrossingS9N1-v0\", render_mode=\"rgb_array\")\n", | |||
|     "env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)\n", | |||
|     "env = ActionMasker(env, mask_fn)\n", | |||
|     "model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)\n", | |||
|     "\n", | |||
|     "\n", | |||
|     "model.learn(10_000)" | |||
|    ] | |||
|   } | |||
|  ], | |||
|  "metadata": { | |||
|   "kernelspec": { | |||
|    "display_name": "env", | |||
|    "language": "python", | |||
|    "name": "python3" | |||
|   }, | |||
|   "language_info": { | |||
|    "codemirror_mode": { | |||
|     "name": "ipython", | |||
|     "version": 3 | |||
|    }, | |||
|    "file_extension": ".py", | |||
|    "mimetype": "text/x-python", | |||
|    "name": "python", | |||
|    "nbconvert_exporter": "python", | |||
|    "pygments_lexer": "ipython3", | |||
|    "version": "3.10.12" | |||
|   }, | |||
|   "orig_nbformat": 4 | |||
|  }, | |||
|  "nbformat": 4, | |||
|  "nbformat_minor": 2 | |||
| } | |||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue