{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prerequisites for applying a shield while training an RL agent in the Minigrid environment with the PPO algorithm are:\n",
    "\n",
    "# Binaries\n",
    "- Tempest\n",
    "- Minigrid2Prism\n",
    "\n",
    "# Python packages\n",
    "- Tempestpy\n",
    "- Minigrid with the printGrid function\n",
    "- ray / rllib"
   ]
  },
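  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before running the notebook it can help to verify that the toolchain is reachable. The cell below is a minimal sketch: the binary locations (a `tempest` executable on the PATH or in the working directory, and `./main` for Minigrid2Prism) are assumptions and may differ on your system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# Hypothetical locations; adjust to where the binaries live on your machine.\n",
    "assert shutil.which(\"tempest\") is not None or os.path.exists(\"./tempest\"), \"Tempest binary not found\"\n",
    "assert os.path.exists(\"./main\"), \"Minigrid2Prism binary (./main) not found\"\n",
    "\n",
    "# The Python-side requirements should simply import.\n",
    "import stormpy.shields  # Tempestpy\n",
    "import minigrid         # patched fork providing printGrid\n",
    "import ray.rllib        # ray / rllib\n",
    "print(\"toolchain looks OK\")"
   ]
  },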
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shield handler is responsible for creating and querying the shield."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import stormpy\n",
    "import stormpy.core\n",
    "import stormpy.simulator\n",
    "\n",
    "import stormpy.shields\n",
    "import stormpy.logic\n",
    "\n",
    "import stormpy.examples\n",
    "import stormpy.examples.files\n",
    "\n",
    "from abc import ABC\n",
    "\n",
    "import os\n",
    "\n",
    "class Action():\n",
    "    def __init__(self, idx, prob=1, labels=None) -> None:\n",
    "        self.idx = idx\n",
    "        self.prob = prob\n",
    "        self.labels = labels if labels is not None else []\n",
    "\n",
    "class ShieldHandler(ABC):\n",
    "    def __init__(self) -> None:\n",
    "        pass\n",
    "    def create_shield(self, **kwargs) -> dict:\n",
    "        pass\n",
    "\n",
    "class MiniGridShieldHandler(ShieldHandler):\n",
    "    def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None:\n",
    "        self.grid_file = grid_file\n",
    "        self.grid_to_prism_path = grid_to_prism_path\n",
    "        self.prism_path = prism_path\n",
    "        self.formula = formula\n",
    "\n",
    "    def __export_grid_to_text(self, env):\n",
    "        with open(self.grid_file, \"w\") as f:\n",
    "            f.write(env.printGrid(init=True))\n",
    "\n",
    "    def __create_prism(self):\n",
    "        result = os.system(F\"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}\")\n",
    "        assert result == 0, \"Prism file could not be generated\"\n",
    "\n",
    "        with open(self.prism_path, \"a\") as f:\n",
    "            f.write(\"label \\\"AgentIsInLava\\\" = AgentIsInLava;\")\n",
    "\n",
    "    def __create_shield_dict(self):\n",
    "        program = stormpy.parse_prism_program(self.prism_path)\n",
    "        shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)\n",
    "\n",
    "        formulas = stormpy.parse_properties_for_prism_program(self.formula, program)\n",
    "        options = stormpy.BuilderOptions([p.raw_formula for p in formulas])\n",
    "        options.set_build_state_valuations(True)\n",
    "        options.set_build_choice_labels(True)\n",
    "        options.set_build_all_labels()\n",
    "        model = stormpy.build_sparse_model_with_options(program, options)\n",
    "\n",
    "        result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)\n",
    "\n",
    "        assert result.has_scheduler\n",
    "        assert result.has_shield\n",
    "        shield = result.shield\n",
    "\n",
    "        action_dictionary = {}\n",
    "        shield_scheduler = shield.construct()\n",
    "\n",
    "        for stateID in model.states:\n",
    "            choice = shield_scheduler.get_choice(stateID)\n",
    "            choices = choice.choice_map\n",
    "            state_valuation = model.state_valuations.get_string(stateID)\n",
    "\n",
    "            actions_to_be_executed = [\n",
    "                Action(idx=choice[1],\n",
    "                       prob=choice[0],\n",
    "                       labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])))\n",
    "                for choice in choices\n",
    "            ]\n",
    "\n",
    "            action_dictionary[state_valuation] = actions_to_be_executed\n",
    "\n",
    "        return action_dictionary\n",
    "\n",
    "    def create_shield(self, **kwargs):\n",
    "        env = kwargs[\"env\"]\n",
    "        self.__export_grid_to_text(env)\n",
    "        self.__create_prism()\n",
    "\n",
    "        return self.__create_shield_dict()\n",
    "\n",
    "def create_shield_query(env):\n",
    "    coordinates = env.env.agent_pos\n",
    "    view_direction = env.env.agent_dir\n",
    "\n",
    "    cur_pos_str = f\"[!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n",
    "\n",
    "    return cur_pos_str"
   ]
  },
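  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the handler can be used on its own. The following cell is a sketch, assuming the Tempest toolchain, the `./main` Minigrid2Prism binary, and the patched Minigrid with `printGrid` are available; the file names are the same placeholders used in the training cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "\n",
    "env = gym.make(\"MiniGrid-LavaCrossingS9N1-v0\")\n",
    "env.reset()\n",
    "\n",
    "handler = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
    "shield = handler.create_shield(env=env)\n",
    "\n",
    "# Query the shield for the agent's current state.\n",
    "query = create_shield_query(env)\n",
    "print(query)\n",
    "print(shield.get(query, \"state not covered by the shield\"))"
   ]
  },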
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To train a learning algorithm with shielding, the allowed actions need to be embedded in the observation.\n",
    "This can be done by implementing a gym wrapper that handles the action embedding for the environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "from minigrid.core.actions import Actions\n",
    "\n",
    "from gymnasium.spaces import Dict, Box\n",
    "from collections import deque\n",
    "from ray.rllib.utils.numpy import one_hot\n",
    "\n",
    "from helpers import get_action_index_mapping, extract_keys\n",
    "\n",
    "class OneHotShieldingWrapper(gym.core.ObservationWrapper):\n",
    "    def __init__(self, env, vector_index, framestack):\n",
    "        super().__init__(env)\n",
    "        self.framestack = framestack\n",
    "        # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.\n",
    "        # +4: Direction.\n",
    "        self.single_frame_dim = 49 * (16 + 6 + 3) + 4\n",
    "        self.init_x = None\n",
    "        self.init_y = None\n",
    "        self.x_positions = []\n",
    "        self.y_positions = []\n",
    "        self.x_y_delta_buffer = deque(maxlen=100)\n",
    "        self.vector_index = vector_index\n",
    "        self.frame_buffer = deque(maxlen=self.framestack)\n",
    "        for _ in range(self.framestack):\n",
    "            self.frame_buffer.append(np.zeros((self.single_frame_dim,)))\n",
    "\n",
    "        self.observation_space = Dict(\n",
    "            {\n",
    "                \"data\": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),\n",
    "                \"action_mask\": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),\n",
    "            }\n",
    "        )\n",
    "\n",
    "    def observation(self, obs):\n",
    "        # Debug output: max-x/y positions to watch exploration progress.\n",
    "        # print(F\"Initial observation in Wrapper {obs}\")\n",
    "        if self.step_count == 0:\n",
    "            for _ in range(self.framestack):\n",
    "                self.frame_buffer.append(np.zeros((self.single_frame_dim,)))\n",
    "            if self.vector_index == 0:\n",
    "                if self.x_positions:\n",
    "                    max_diff = max(\n",
    "                        np.sqrt(\n",
    "                            (np.array(self.x_positions) - self.init_x) ** 2\n",
    "                            + (np.array(self.y_positions) - self.init_y) ** 2\n",
    "                        )\n",
    "                    )\n",
    "                    self.x_y_delta_buffer.append(max_diff)\n",
    "                    print(\n",
    "                        \"100-average dist travelled={}\".format(\n",
    "                            np.mean(self.x_y_delta_buffer)\n",
    "                        )\n",
    "                    )\n",
    "                    self.x_positions = []\n",
    "                    self.y_positions = []\n",
    "                self.init_x = self.agent_pos[0]\n",
    "                self.init_y = self.agent_pos[1]\n",
    "\n",
    "        self.x_positions.append(self.agent_pos[0])\n",
    "        self.y_positions.append(self.agent_pos[1])\n",
    "\n",
    "        image = obs[\"data\"]\n",
    "\n",
    "        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.\n",
    "        objects = one_hot(image[:, :, 0], depth=16)\n",
    "        colors = one_hot(image[:, :, 1], depth=6)\n",
    "        states = one_hot(image[:, :, 2], depth=3)\n",
    "\n",
    "        all_ = np.concatenate([objects, colors, states], -1)\n",
    "        all_flat = np.reshape(all_, (-1,))\n",
    "        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)\n",
    "        single_frame = np.concatenate([all_flat, direction])\n",
    "        self.frame_buffer.append(single_frame)\n",
    "\n",
    "        return {\"data\": np.concatenate(self.frame_buffer), \"action_mask\": obs[\"action_mask\"]}\n",
    "\n",
    "# Environment wrapper handling action embedding in observations\n",
    "class MiniGridShieldingWrapper(gym.core.Wrapper):\n",
    "    def __init__(self,\n",
    "                 env,\n",
    "                 shield_creator: ShieldHandler,\n",
    "                 shield_query_creator,\n",
    "                 create_shield_at_reset=True,\n",
    "                 mask_actions=True):\n",
    "        super(MiniGridShieldingWrapper, self).__init__(env)\n",
    "        self.max_available_actions = env.action_space.n\n",
    "        self.observation_space = Dict(\n",
    "            {\n",
    "                \"data\": env.observation_space.spaces[\"image\"],\n",
    "                \"action_mask\": Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),\n",
    "            }\n",
    "        )\n",
    "        self.shield_creator = shield_creator\n",
    "        self.create_shield_at_reset = create_shield_at_reset\n",
    "        self.shield = shield_creator.create_shield(env=self.env)\n",
    "        self.mask_actions = mask_actions\n",
    "        self.shield_query_creator = shield_query_creator\n",
    "\n",
    "    def create_action_mask(self):\n",
    "        if not self.mask_actions:\n",
    "            return np.array([1.0] * self.max_available_actions, dtype=np.int8)\n",
    "\n",
    "        cur_pos_str = self.shield_query_creator(self.env)\n",
    "\n",
    "        # Create the mask:\n",
    "        # if the shield restricts the current state, mark only allowed actions with 1.0;\n",
    "        # otherwise mark all actions as valid.\n",
    "        allowed_actions = []\n",
    "        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n",
    "\n",
    "        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
    "            allowed_actions = self.shield[cur_pos_str]\n",
    "            for allowed_action in allowed_actions:\n",
    "                index = get_action_index_mapping(allowed_action.labels)  # allowed_action.labels is a set of choice labels\n",
    "                assert index is not None\n",
    "\n",
    "                allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
    "                mask[index] = allowed\n",
    "\n",
    "        else:\n",
    "            mask[:] = 1.0\n",
    "\n",
    "        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])\n",
    "\n",
    "        if front_tile is not None and front_tile.type == \"key\":\n",
    "            mask[Actions.pickup] = 1.0\n",
    "\n",
    "        if front_tile and front_tile.type == \"door\":\n",
    "            mask[Actions.toggle] = 1.0\n",
    "\n",
    "        return mask\n",
    "\n",
    "    def reset(self, *, seed=None, options=None):\n",
    "        obs, infos = self.env.reset(seed=seed, options=options)\n",
    "\n",
    "        if self.create_shield_at_reset and self.mask_actions:\n",
    "            self.shield = self.shield_creator.create_shield(env=self.env)\n",
    "\n",
    "        self.keys = extract_keys(self.env)\n",
    "        mask = self.create_action_mask()\n",
    "        return {\n",
    "            \"data\": obs[\"image\"],\n",
    "            \"action_mask\": mask\n",
    "        }, infos\n",
    "\n",
    "    def step(self, action):\n",
    "        orig_obs, rew, done, truncated, info = self.env.step(action)\n",
    "\n",
    "        mask = self.create_action_mask()\n",
    "        obs = {\n",
    "            \"data\": orig_obs[\"image\"],\n",
    "            \"action_mask\": mask,\n",
    "        }\n",
    "\n",
    "        return obs, rew, done, truncated, info\n",
    "\n",
    "# Wrapper to use with a Stable Baselines algorithm\n",
    "class MiniGridSbShieldingWrapper(gym.core.Wrapper):\n",
    "    def __init__(self,\n",
    "                 env,\n",
    "                 shield_creator: ShieldHandler,\n",
    "                 shield_query_creator,\n",
    "                 create_shield_at_reset=True,\n",
    "                 mask_actions=True,\n",
    "                 ):\n",
    "        super(MiniGridSbShieldingWrapper, self).__init__(env)\n",
    "        self.max_available_actions = env.action_space.n\n",
    "        self.observation_space = env.observation_space.spaces[\"image\"]\n",
    "\n",
    "        self.shield_creator = shield_creator\n",
    "        self.mask_actions = mask_actions\n",
    "        self.shield_query_creator = shield_query_creator\n",
    "\n",
    "    def create_action_mask(self):\n",
    "        if not self.mask_actions:\n",
    "            return np.array([1.0] * self.max_available_actions, dtype=np.int8)\n",
    "\n",
    "        cur_pos_str = self.shield_query_creator(self.env)\n",
    "\n",
    "        allowed_actions = []\n",
    "\n",
    "        # Create the mask:\n",
    "        # if the shield restricts the current state, mark only allowed actions with 1.0;\n",
    "        # otherwise mark all actions as valid.\n",
    "        mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n",
    "\n",
    "        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
    "            allowed_actions = self.shield[cur_pos_str]\n",
    "            for allowed_action in allowed_actions:\n",
    "                index = get_action_index_mapping(allowed_action.labels)\n",
    "                assert index is not None\n",
    "\n",
    "                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
    "        else:\n",
    "            mask[:] = 1.0\n",
    "\n",
    "        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])\n",
    "\n",
    "        if front_tile and front_tile.type == \"door\":\n",
    "            mask[Actions.toggle] = 1.0\n",
    "\n",
    "        return mask\n",
    "\n",
    "    def reset(self, *, seed=None, options=None):\n",
    "        obs, infos = self.env.reset(seed=seed, options=options)\n",
    "\n",
    "        keys = extract_keys(self.env)\n",
    "        shield = self.shield_creator.create_shield(env=self.env)\n",
    "\n",
    "        self.keys = keys\n",
    "        self.shield = shield\n",
    "        return obs[\"image\"], infos\n",
    "\n",
    "    def step(self, action):\n",
    "        orig_obs, rew, done, truncated, info = self.env.step(action)\n",
    "        obs = orig_obs[\"image\"]\n",
    "\n",
    "        return obs, rew, done, truncated, info\n"
   ]
  },
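  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short smoke test of the wrapper stack (a sketch under the same toolchain assumptions as above): wrap an environment, reset it, and inspect the action mask that the shield produces for the initial state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make(\"MiniGrid-LavaCrossingS9N1-v0\")\n",
    "shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
    "env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)\n",
    "\n",
    "obs, info = env.reset()\n",
    "print(obs[\"action_mask\"])  # 1 marks actions the shield allows in the initial state"
   ]
  },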
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we want to use RLlib algorithms, we additionally need a model that performs the action masking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC\n",
    "from ray.rllib.models.torch.torch_modelv2 import TorchModelV2\n",
    "from ray.rllib.utils.framework import try_import_torch\n",
    "from ray.rllib.utils.torch_utils import FLOAT_MIN\n",
    "\n",
    "torch, nn = try_import_torch()\n",
    "\n",
    "class TorchActionMaskModel(TorchModelV2, nn.Module):\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        obs_space,\n",
    "        action_space,\n",
    "        num_outputs,\n",
    "        model_config,\n",
    "        name,\n",
    "        **kwargs,\n",
    "    ):\n",
    "        orig_space = getattr(obs_space, \"original_space\", obs_space)\n",
    "\n",
    "        TorchModelV2.__init__(\n",
    "            self, obs_space, action_space, num_outputs, model_config, name, **kwargs\n",
    "        )\n",
    "        nn.Module.__init__(self)\n",
    "\n",
    "        self.internal_model = TorchFC(\n",
    "            orig_space[\"data\"],\n",
    "            action_space,\n",
    "            num_outputs,\n",
    "            model_config,\n",
    "            name + \"_internal\",\n",
    "        )\n",
    "\n",
    "    def forward(self, input_dict, state, seq_lens):\n",
    "        # Compute the unmasked logits from the \"data\" part of the observation.\n",
    "        logits, _ = self.internal_model({\"obs\": input_dict[\"obs\"][\"data\"]})\n",
    "\n",
    "        # Extract the available-actions mask from the observation.\n",
    "        action_mask = input_dict[\"obs\"][\"action_mask\"]\n",
    "\n",
    "        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)\n",
    "        masked_logits = logits + inf_mask\n",
    "\n",
    "        # Return masked logits.\n",
    "        return masked_logits, state\n",
    "\n",
    "    def value_function(self):\n",
    "        return self.internal_model.value_function()"
   ]
  },
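  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The masking trick works because `log(1) = 0` leaves allowed logits unchanged, while `log(0) = -inf` is clamped to `FLOAT_MIN`, which drives the softmax probability of forbidden actions to (effectively) zero. A tiny standalone illustration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from ray.rllib.utils.torch_utils import FLOAT_MIN\n",
    "\n",
    "logits = torch.tensor([1.0, 2.0, 3.0])\n",
    "mask = torch.tensor([1.0, 0.0, 1.0])  # action 1 is forbidden by the shield\n",
    "\n",
    "masked_logits = logits + torch.clamp(torch.log(mask), min=FLOAT_MIN)\n",
    "print(torch.softmax(masked_logits, dim=-1))  # action 1 gets ~0 probability"
   ]
  },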
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using these components, we can now train an RL agent with shielding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import minigrid\n",
    "\n",
    "from ray import tune, air\n",
    "from ray.tune import register_env\n",
    "from ray.rllib.algorithms.ppo import PPOConfig\n",
    "from ray.rllib.models import ModelCatalog\n",
    "\n",
    "\n",
    "def shielding_env_creator(config):\n",
    "    name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n",
    "    framestack = config.get(\"framestack\", 4)\n",
    "\n",
    "    shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
    "\n",
    "    env = gym.make(name)\n",
    "    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)\n",
    "    env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, \"vector_index\") else 0,\n",
    "                                 framestack=framestack)\n",
    "\n",
    "    return env\n",
    "\n",
    "\n",
    "def register_minigrid_shielding_env():\n",
    "    env_name = \"mini-grid-shielding\"\n",
    "    register_env(env_name, shielding_env_creator)\n",
    "    ModelCatalog.register_custom_model(\n",
    "        \"shielding_model\",\n",
    "        TorchActionMaskModel)\n",
    "\n",
    "register_minigrid_shielding_env()\n",
    "\n",
    "\n",
    "config = (PPOConfig()\n",
    "    .rollouts(num_rollout_workers=1)\n",
    "    .resources(num_gpus=0)\n",
    "    .environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n",
    "    .framework(\"torch\")\n",
    "    .rl_module(_enable_rl_module_api=False)\n",
    "    .training(_enable_learner_api=False, model={\n",
    "        \"custom_model\": \"shielding_model\"\n",
    "    }))\n",
    "\n",
    "tuner = tune.Tuner(\"PPO\",\n",
    "                   tune_config=tune.TuneConfig(\n",
    "                       metric=\"episode_reward_mean\",\n",
    "                       mode=\"max\",\n",
    "                       num_samples=1,\n",
    "                   ),\n",
    "                   run_config=air.RunConfig(\n",
    "                       stop={\"episode_reward_mean\": 94,\n",
    "                             \"timesteps_total\": 12000,\n",
    "                             \"training_iteration\": 12},\n",
    "                       checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),\n",
    "                   ),\n",
    "                   param_space=config)\n",
    "\n",
    "results = tuner.fit()\n",
    "best_result = results.get_best_result()\n",
    "\n",
    "import pprint\n",
    "\n",
    "metrics_to_print = [\n",
    "    \"episode_reward_mean\",\n",
    "    \"episode_reward_max\",\n",
    "    \"episode_reward_min\",\n",
    "    \"episode_len_mean\",\n",
    "]\n",
    "pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})"
   ]
  },
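  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, the trained policy can be loaded back from the best checkpoint. This is a sketch assuming a recent Ray 2.x, where `Algorithm.from_checkpoint` restores an algorithm from a Tune result; the shielded environment and custom model must already be registered (as above) in the restoring process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ray.rllib.algorithms.algorithm import Algorithm\n",
    "\n",
    "# Restore the algorithm from the checkpoint written at the end of training.\n",
    "algo = Algorithm.from_checkpoint(best_result.checkpoint)\n",
    "\n",
    "# Roll out one shielded episode with the trained policy.\n",
    "env = shielding_env_creator({\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n",
    "obs, info = env.reset()\n",
    "terminated = truncated = False\n",
    "while not (terminated or truncated):\n",
    "    action = algo.compute_single_action(obs)\n",
    "    obs, reward, terminated, truncated, info = env.step(action)\n",
    "print(\"episode reward:\", reward)"
   ]
  }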
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}