{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: combining shielding with RLlib's DQN algorithm\n",
    "\n",
    "A shield built by `MiniGridShieldHandler` is attached to a MiniGrid environment through `MiniGridShieldingWrapper` (with action masking enabled), the custom `TorchActionMaskModel` is registered as the DQN model, and training is run with Ray Tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "import minigrid  # noqa: F401 -- imported for its side effect of registering the MiniGrid environments\n",
    "\n",
    "from ray import tune, air\n",
    "from ray.tune import register_env\n",
    "from ray.rllib.algorithms.dqn.dqn import DQNConfig\n",
    "from ray.tune.logger import pretty_print\n",
    "from ray.rllib.models import ModelCatalog\n",
    "\n",
    "from torch_action_mask_model import TorchActionMaskModel\n",
    "from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper\n",
    "from shieldhandlers import MiniGridShieldHandler, create_shield_query"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Environment and model registration\n",
    "\n",
    "The env creator reads its settings from RLlib's `env_config`. Every key is optional and defaults to the values used in this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def shielding_env_creater(config):\n",
    "    \"\"\"Build the shielded MiniGrid environment used for training.\n",
    "\n",
    "    Args:\n",
    "        config: RLlib env config (an ``EnvContext`` when RLlib calls this; a plain\n",
    "            dict also works). Optional keys and their defaults:\n",
    "            ``name`` (\"MiniGrid-LavaCrossingS9N1-v0\"), ``framestack`` (4),\n",
    "            ``shield_grid_file`` (\"grid.txt\"), ``shield_executable`` (\"./main\",\n",
    "            presumably a path to an executable -- confirm against shieldhandlers),\n",
    "            ``shield_prism_file`` (\"grid.prism\") and ``shield_query``\n",
    "            ('Pmax=? [G !\"AgentIsInLavaAndNotDone\"]').\n",
    "\n",
    "    Returns:\n",
    "        The gym env wrapped in ``MiniGridShieldingWrapper`` (action masking on)\n",
    "        and then in ``OneHotShieldingWrapper``.\n",
    "    \"\"\"\n",
    "    name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n",
    "    framestack = config.get(\"framestack\", 4)\n",
    "\n",
    "    # Shield inputs, passed positionally to MiniGridShieldHandler. The defaults are the\n",
    "    # values this example always used; override them through env_config.\n",
    "    grid_file = config.get(\"shield_grid_file\", \"grid.txt\")\n",
    "    executable = config.get(\"shield_executable\", \"./main\")\n",
    "    prism_file = config.get(\"shield_prism_file\", \"grid.prism\")\n",
    "    query = config.get(\"shield_query\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
    "    shield_creator = MiniGridShieldHandler(grid_file, executable, prism_file, query)\n",
    "\n",
    "    env = gym.make(name)\n",
    "    env = MiniGridShieldingWrapper(\n",
    "        env,\n",
    "        shield_creator=shield_creator,\n",
    "        shield_query_creator=create_shield_query,\n",
    "        mask_actions=True,\n",
    "    )\n",
    "    # RLlib's EnvContext carries `vector_index`; a plain dict does not, so fall back to 0.\n",
    "    env = OneHotShieldingWrapper(env, getattr(config, \"vector_index\", 0), framestack=framestack)\n",
    "\n",
    "    return env\n",
    "\n",
    "\n",
    "def register_minigrid_shielding_env():\n",
    "    \"\"\"Register the shielded env creator and the action-masking model with RLlib.\n",
    "\n",
    "    The env creator is registered under \"mini-grid-shielding\" and\n",
    "    ``TorchActionMaskModel`` under \"shielding_model\"; the training config\n",
    "    refers to both names.\n",
    "    \"\"\"\n",
    "    env_name = \"mini-grid-shielding\"\n",
    "    register_env(env_name, shielding_env_creater)\n",
    "    ModelCatalog.register_custom_model(\"shielding_model\", TorchActionMaskModel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    "\n",
    "Register the environment and model, build the DQN configuration, and run it with Ray Tune."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "register_minigrid_shielding_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = (\n",
    "    DQNConfig()\n",
    "    .resources(num_gpus=0)\n",
    "    .rollouts(num_rollout_workers=1)\n",
    "    .environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n",
    "    .framework(\"torch\")\n",
    "    # The custom model is registered through ModelCatalog (ModelV2 stack), so keep the RLModule API off.\n",
    "    .rl_module(_enable_rl_module_api=False)\n",
    "    # No extra Q-head layers and no dueling: presumably so the output of TorchActionMaskModel,\n",
    "    # including its action mask, is used as the Q-values unchanged (confirm).\n",
    "    .training(hiddens=[], dueling=False, model={\"custom_model\": \"shielding_model\"})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tuner = tune.Tuner(\n",
    "    \"DQN\",\n",
    "    tune_config=tune.TuneConfig(\n",
    "        metric=\"episode_reward_mean\",\n",
    "        mode=\"max\",\n",
    "        num_samples=1,\n",
    "    ),\n",
    "    run_config=air.RunConfig(\n",
    "        # Tune stops the trial as soon as ANY of these criteria is met.\n",
    "        # NOTE(review): MiniGrid's native episode reward is at most 1, so a mean of 94 is only\n",
    "        # reachable if the wrappers rescale rewards; otherwise the timestep/iteration limits\n",
    "        # end training. Confirm the intended threshold.\n",
    "        stop={\n",
    "            \"episode_reward_mean\": 94,\n",
    "            \"timesteps_total\": 12000,\n",
    "            \"training_iteration\": 12,\n",
    "        },\n",
    "        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),\n",
    "    ),\n",
    "    param_space=config,\n",
    ")\n",
    "\n",
    "tuner.fit()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}