{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: combining shielding with RLlib's DQN algorithm\n",
    "\n",
    "This notebook registers a shielded MiniGrid environment and a custom action-masking model with RLlib, then trains a DQN agent on it and saves checkpoints along the way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "import minigrid\n",
    "\n",
    "from ray.tune import register_env\n",
    "from ray.rllib.algorithms.dqn.dqn import DQNConfig\n",
    "from ray.tune.logger import pretty_print\n",
    "from ray.rllib.models import ModelCatalog\n",
    "\n",
    "\n",
    "from torch_action_mask_model import TorchActionMaskModel\n",
    "from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper\n",
    "from shieldhandlers import MiniGridShieldHandler, create_shield_query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shielding_env_creater(config):\n",
    "    \"\"\"Create the shielded MiniGrid environment for an RLlib worker.\n",
    "\n",
    "    Args:\n",
    "        config: Env config handed over by RLlib. The optional keys ``\"name\"``\n",
    "            (Gymnasium env id, default ``\"MiniGrid-LavaCrossingS9N1-v0\"``) and\n",
    "            ``\"framestack\"`` (default ``4``) are read via ``.get``. If the object\n",
    "            has a ``vector_index`` attribute it is forwarded to\n",
    "            ``OneHotShieldingWrapper``; otherwise ``0`` is used.\n",
    "\n",
    "    Returns:\n",
    "        The Gymnasium env wrapped first in ``MiniGridShieldingWrapper`` (with\n",
    "        ``mask_actions=True``) and then in ``OneHotShieldingWrapper``.\n",
    "    \"\"\"\n",
    "    name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n",
    "    framestack = config.get(\"framestack\", 4)\n",
    "\n",
    "    shield_creator = MiniGridShieldHandler(\n",
    "        \"grid.txt\",\n",
    "        \"./main\",\n",
    "        \"grid.prism\",\n",
    "        \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\",\n",
    "    )\n",
    "\n",
    "    env = gym.make(name)\n",
    "    env = MiniGridShieldingWrapper(\n",
    "        env,\n",
    "        shield_creator=shield_creator,\n",
    "        shield_query_creator=create_shield_query,\n",
    "        mask_actions=True,\n",
    "    )\n",
    "\n",
    "    # A plain dict has no `vector_index` attribute, so fall back to 0 in that case.\n",
    "    vector_index = getattr(config, \"vector_index\", 0)\n",
    "    env = OneHotShieldingWrapper(env, vector_index, framestack=framestack)\n",
    "\n",
    "    return env\n",
    "\n",
    "\n",
    "def register_minigrid_shielding_env():\n",
    "    \"\"\"Register the shielded env and the action-masking model with RLlib.\n",
    "\n",
    "    Registers ``shielding_env_creater`` under the env name\n",
    "    ``\"mini-grid-shielding\"`` and ``TorchActionMaskModel`` under the\n",
    "    custom-model name ``\"shielding_model\"``.\n",
    "    \"\"\"\n",
    "    env_name = \"mini-grid-shielding\"\n",
    "    register_env(env_name, shielding_env_creater)\n",
    "    ModelCatalog.register_custom_model(\"shielding_model\", TorchActionMaskModel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule: total number of `algo.train()` calls and how often to checkpoint.\n",
    "NUM_ITERATIONS = 30\n",
    "CHECKPOINT_INTERVAL = 5\n",
    "\n",
    "register_minigrid_shielding_env()\n",
    "\n",
    "config = (\n",
    "    DQNConfig()\n",
    "    .resources(num_gpus=0)\n",
    "    .rollouts(num_rollout_workers=1)\n",
    "    .environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n",
    "    .framework(\"torch\")\n",
    "    .rl_module(_enable_rl_module_api=False)\n",
    "    .training(hiddens=[], dueling=False, model={\"custom_model\": \"shielding_model\"})\n",
    ")\n",
    "\n",
    "algo = config.build()\n",
    "\n",
    "try:\n",
    "    for i in range(NUM_ITERATIONS):\n",
    "        result = algo.train()\n",
    "        print(pretty_print(result))\n",
    "\n",
    "        # Checkpoint periodically, and always after the last iteration so the\n",
    "        # final trained weights are not lost.\n",
    "        if i % CHECKPOINT_INTERVAL == 0 or i == NUM_ITERATIONS - 1:\n",
    "            print(\"Saving checkpoint\")\n",
    "            checkpoint_dir = algo.save()\n",
    "            print(f\"Checkpoint saved in directory {checkpoint_dir}\")\n",
    "finally:\n",
    "    # Release the algorithm's workers even if training fails or is interrupted.\n",
    "    algo.stop()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}