You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
557 lines
22 KiB
557 lines
22 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The prerequisites for applying a shield while training an RL agent in the MiniGrid environment with the PPO algorithm are:\n",
|
|
"\n",
|
|
"# Binaries\n",
|
|
"- Tempest\n",
|
|
"- Minigrid2Prism\n",
|
|
"\n",
|
|
"\n",
|
|
"# Python packages:\n",
|
|
"- Tempestpy\n",
|
|
"- Minigrid with the printGrid Function\n",
|
|
"- ray / rllib"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"The shield handler is responsible for creating and querying the shield."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"import stormpy\n",
|
|
"import stormpy.core\n",
|
|
"import stormpy.simulator\n",
|
|
"\n",
|
|
"import stormpy.shields\n",
|
|
"import stormpy.logic\n",
|
|
"\n",
|
|
"import stormpy.examples\n",
|
|
"import stormpy.examples.files\n",
|
|
"\n",
|
|
"from abc import ABC\n",
|
|
"\n",
|
|
"import os\n",
|
|
"\n",
|
"class Action():\n",
"    \"\"\"One choice offered by the shield: a choice index, its probability and its labels.\n",
"\n",
"    Args:\n",
"        idx: Index of the choice.\n",
"        prob: Probability attached to the choice (defaults to 1).\n",
"        labels: Labels of the choice; a new empty list is used when omitted.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, idx, prob=1, labels=None) -> None:\n",
"        self.idx = idx\n",
"        self.prob = prob\n",
"        # A literal `[]` default is created once and shared by every instance\n",
"        # built without labels, so the empty list is created per instance here.\n",
"        self.labels = labels if labels is not None else []\n",
|
|
"\n",
|
"class ShieldHandler(ABC):\n",
"    \"\"\"Base class for objects that create a shield.\"\"\"\n",
"\n",
"    def __init__(self) -> None:\n",
"        \"\"\"The base class keeps no state.\"\"\"\n",
"        return None\n",
"\n",
"    def create_shield(self, **kwargs) -> dict:\n",
"        \"\"\"Create the shield; this base implementation does nothing and returns None.\"\"\"\n",
"        return None\n",
|
|
"\n",
|
"class MiniGridShieldHandler(ShieldHandler):\n",
"    \"\"\"Creates a shield for a MiniGrid environment via a PRISM file and stormpy.\n",
"\n",
"    Args:\n",
"        grid_file: Path of the text file the grid is exported to.\n",
"        grid_to_prism_path: Command that converts the grid file into a PRISM file.\n",
"        prism_path: Path of the PRISM file written by that command.\n",
"        formula: Property string that is model checked to obtain the shield.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None:\n",
"        self.grid_file = grid_file\n",
"        self.grid_to_prism_path = grid_to_prism_path\n",
"        self.prism_path = prism_path\n",
"        self.formula = formula\n",
"\n",
"    def __export_grid_to_text(self, env):\n",
"        \"\"\"Write the text returned by `env.printGrid(init=True)` to `grid_file`.\"\"\"\n",
"        # The context manager closes the file even if printGrid() raises.\n",
"        with open(self.grid_file, \"w\") as grid_out:\n",
"            grid_out.write(env.printGrid(init=True))\n",
"\n",
"    def __create_prism(self):\n",
"        \"\"\"Run the grid-to-PRISM command and append the lava label to its output.\n",
"\n",
"        Raises:\n",
"            AssertionError: If the command exits with a non-zero status.\n",
"        \"\"\"\n",
"        import shlex\n",
"\n",
"        # Quote the file names so the shell passes them on exactly as Python opens\n",
"        # them, even if they contain spaces or shell metacharacters. The command\n",
"        # itself is left to the shell, as before.\n",
"        grid_arg = shlex.quote(str(self.grid_file))\n",
"        prism_arg = shlex.quote(str(self.prism_path))\n",
"        result = os.system(f\"{self.grid_to_prism_path} -v 'agent' -i {grid_arg} -o {prism_arg}\")\n",
"\n",
"        # Raised explicitly so the check is not stripped by `python -O`.\n",
"        if result != 0:\n",
"            raise AssertionError(\"Prism file could not be generated\")\n",
"\n",
"        with open(self.prism_path, \"a\") as prism_out:\n",
"            prism_out.write(\"label \\\"AgentIsInLava\\\" = AgentIsInLava;\")\n",
"\n",
"    def __create_shield_dict(self):\n",
"        \"\"\"Model check the PRISM program and collect the shield's choices per state.\n",
"\n",
"        Returns:\n",
"            dict: Maps the valuation string of every model state to the list of\n",
"            `Action` objects built from the shield scheduler's choice map.\n",
"\n",
"        Raises:\n",
"            AssertionError: If the result has no scheduler or no shield.\n",
"        \"\"\"\n",
"        program = stormpy.parse_prism_program(self.prism_path)\n",
"        shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)\n",
"\n",
"        formulas = stormpy.parse_properties_for_prism_program(self.formula, program)\n",
"        options = stormpy.BuilderOptions([p.raw_formula for p in formulas])\n",
"        options.set_build_state_valuations(True)\n",
"        options.set_build_choice_labels(True)\n",
"        options.set_build_all_labels()\n",
"        model = stormpy.build_sparse_model_with_options(program, options)\n",
"\n",
"        result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)\n",
"\n",
"        # Explicit raises keep these checks active under `python -O`.\n",
"        if not result.has_scheduler:\n",
"            raise AssertionError(\"Model checking did not produce a scheduler\")\n",
"        if not result.has_shield:\n",
"            raise AssertionError(\"Model checking did not produce a shield\")\n",
"        shield_scheduler = result.shield.construct()\n",
"\n",
"        action_dictionary = {}\n",
"        for stateID in model.states:\n",
"            choice = shield_scheduler.get_choice(stateID)\n",
"            state_valuation = model.state_valuations.get_string(stateID)\n",
"\n",
"            # Every entry of the choice map is read as (probability, choice index).\n",
"            action_dictionary[state_valuation] = [\n",
"                Action(\n",
"                    idx=entry[1],\n",
"                    prob=entry[0],\n",
"                    labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, entry[1])),\n",
"                )\n",
"                for entry in choice.choice_map\n",
"            ]\n",
"\n",
"        return action_dictionary\n",
"\n",
"    def create_shield(self, **kwargs):\n",
"        \"\"\"Export the grid of `kwargs[\"env\"]`, convert it to PRISM and build the shield.\n",
"\n",
"        Returns:\n",
"            dict: The dictionary built by the model-checking step.\n",
"\n",
"        Raises:\n",
"            KeyError: If no `env` keyword argument is passed.\n",
"        \"\"\"\n",
"        env = kwargs[\"env\"]\n",
"        self.__export_grid_to_text(env)\n",
"        self.__create_prism()\n",
"\n",
"        return self.__create_shield_dict()\n",
|
|
" \n",
|
"def create_shield_query(env):\n",
"    \"\"\"Build the shield lookup string from the agent position and direction of `env.env`.\"\"\"\n",
"    inner = env.env\n",
"    position = inner.agent_pos\n",
"    direction = inner.agent_dir\n",
"\n",
"    return \"[!AgentDone\\t& xAgent={}\\t& yAgent={}\\t& viewAgent={}]\".format(position[0], position[1], direction)\n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"To train a learning algorithm with shielding, the allowed actions need to be embedded in the observation. \n",
|
|
"This can be done by implementing a gym wrapper that handles the action embedding for the environment."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import gymnasium as gym\n",
|
|
"import numpy as np\n",
|
|
"import random\n",
|
|
"\n",
|
|
"from minigrid.core.actions import Actions\n",
|
|
"\n",
|
|
"from gymnasium.spaces import Dict, Box\n",
|
|
"from collections import deque\n",
|
|
"from ray.rllib.utils.numpy import one_hot\n",
|
|
"\n",
|
|
"from helpers import get_action_index_mapping, extract_keys\n",
|
|
"\n",
|
"class OneHotShieldingWrapper(gym.core.ObservationWrapper):\n",
"    \"\"\"One-hot encodes the image of a shielded observation and stacks frames.\n",
"\n",
"    Incoming observations are dicts with the keys \"data\" and \"action_mask\".\n",
"    The returned dict has the same keys: \"data\" is the concatenation of the\n",
"    last `framestack` encoded frames, \"action_mask\" is passed on unchanged.\n",
"\n",
"    Args:\n",
"        env: The environment to wrap.\n",
"        vector_index: Index of this environment copy; only index 0 prints the\n",
"            exploration debug output.\n",
"        framestack: Number of frames concatenated into one observation.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, env, vector_index, framestack):\n",
"        super().__init__(env)\n",
"        self.framestack = framestack\n",
"        # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.\n",
"        # +4: Direction.\n",
"        self.single_frame_dim = 49 * (16 + 6 + 3) + 4\n",
"        self.init_x = None\n",
"        self.init_y = None\n",
"        self.x_positions = []\n",
"        self.y_positions = []\n",
"        self.x_y_delta_buffer = deque(maxlen=100)\n",
"        self.vector_index = vector_index\n",
"        self.frame_buffer = deque(maxlen=self.framestack)\n",
"        self._clear_frame_buffer()\n",
"\n",
"        self.observation_space = Dict(\n",
"            {\n",
"                \"data\": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),\n",
"                \"action_mask\": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),\n",
"            }\n",
"        )\n",
"\n",
"    def _clear_frame_buffer(self):\n",
"        \"\"\"Fill the frame buffer with all-zero frames.\"\"\"\n",
"        for _ in range(self.framestack):\n",
"            # float32 matches the dtype declared for \"data\"; NumPy's default\n",
"            # float64 would make the whole concatenated observation float64.\n",
"            self.frame_buffer.append(np.zeros((self.single_frame_dim,), dtype=np.float32))\n",
"\n",
"    def observation(self, obs):\n",
"        \"\"\"Encode `obs[\"data\"]`, push it onto the frame buffer and return the stack.\n",
"\n",
"        Returns:\n",
"            dict: \"data\" with the concatenated frame buffer (float32) and the\n",
"            unchanged \"action_mask\" of `obs`.\n",
"        \"\"\"\n",
"        # Debug output: max-x/y positions to watch exploration progress.\n",
"        if self.step_count == 0:\n",
"            self._clear_frame_buffer()\n",
"            if self.vector_index == 0:\n",
"                if self.x_positions:\n",
"                    max_diff = max(\n",
"                        np.sqrt(\n",
"                            (np.array(self.x_positions) - self.init_x) ** 2\n",
"                            + (np.array(self.y_positions) - self.init_y) ** 2\n",
"                        )\n",
"                    )\n",
"                    self.x_y_delta_buffer.append(max_diff)\n",
"                    print(\n",
"                        \"100-average dist travelled={}\".format(\n",
"                            np.mean(self.x_y_delta_buffer)\n",
"                        )\n",
"                    )\n",
"                    self.x_positions = []\n",
"                    self.y_positions = []\n",
"                self.init_x = self.agent_pos[0]\n",
"                self.init_y = self.agent_pos[1]\n",
"\n",
"        self.x_positions.append(self.agent_pos[0])\n",
"        self.y_positions.append(self.agent_pos[1])\n",
"\n",
"        image = obs[\"data\"]\n",
"\n",
"        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.\n",
"        objects = one_hot(image[:, :, 0], depth=16)\n",
"        colors = one_hot(image[:, :, 1], depth=6)\n",
"        states = one_hot(image[:, :, 2], depth=3)\n",
"\n",
"        all_ = np.concatenate([objects, colors, states], -1)\n",
"        all_flat = np.reshape(all_, (-1,))\n",
"        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)\n",
"        single_frame = np.concatenate([all_flat, direction])\n",
"        self.frame_buffer.append(single_frame)\n",
"\n",
"        stacked = np.concatenate(self.frame_buffer).astype(np.float32, copy=False)\n",
"        return {\"data\": stacked, \"action_mask\": obs[\"action_mask\"]}\n",
|
|
"\n",
|
"# Environment wrapper handling action embedding in observations\n",
"class MiniGridShieldingWrapper(gym.core.Wrapper):\n",
"    \"\"\"Adds the shield's action mask to the observations of a MiniGrid environment.\n",
"\n",
"    Observations become dicts with \"data\" (the wrapped environment's \"image\")\n",
"    and \"action_mask\" (one int8 entry per action).\n",
"\n",
"    Args:\n",
"        env: The environment to wrap.\n",
"        shield_creator: Handler whose `create_shield(env=...)` returns the shield.\n",
"        shield_query_creator: Callable that maps the wrapped environment to the\n",
"            key under which the current state is looked up in the shield.\n",
"        create_shield_at_reset: Re-create the shield on every reset (only done\n",
"            while `mask_actions` is set).\n",
"        mask_actions: If False, the mask always allows every action.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 env,\n",
"                 shield_creator : ShieldHandler,\n",
"                 shield_query_creator,\n",
"                 create_shield_at_reset=True,\n",
"                 mask_actions=True):\n",
"        super(MiniGridShieldingWrapper, self).__init__(env)\n",
"        self.max_available_actions = env.action_space.n\n",
"        self.observation_space = Dict(\n",
"            {\n",
"                \"data\": env.observation_space.spaces[\"image\"],\n",
"                \"action_mask\" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),\n",
"            }\n",
"        )\n",
"        self.shield_creator = shield_creator\n",
"        self.create_shield_at_reset = create_shield_at_reset\n",
"        self.shield = shield_creator.create_shield(env=self.env)\n",
"        self.mask_actions = mask_actions\n",
"        self.shield_query_creator = shield_query_creator\n",
"\n",
"    def create_action_mask(self):\n",
"        \"\"\"Return the int8 mask of the currently allowed actions (1 = allowed).\n",
"\n",
"        Raises:\n",
"            AssertionError: If a shield choice cannot be mapped to an action index.\n",
"        \"\"\"\n",
"        if not self.mask_actions:\n",
"            return np.ones(self.max_available_actions, dtype=np.int8)\n",
"\n",
"        cur_pos_str = self.shield_query_creator(self.env)\n",
"\n",
"        # If the shield has choices for this state, only those can be enabled;\n",
"        # otherwise every action is valid.\n",
"        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
"            mask = np.zeros(self.max_available_actions, dtype=np.int8)\n",
"            for allowed_action in self.shield[cur_pos_str]:\n",
"                index = get_action_index_mapping(allowed_action.labels)\n",
"                if index is None:\n",
"                    # Raised explicitly: under `python -O` a bare assert is skipped\n",
"                    # and `mask[None] = ...` would then overwrite every entry.\n",
"                    raise AssertionError(f\"No action index for labels {allowed_action.labels}\")\n",
"\n",
"                # Enable the action with the probability attached to the choice.\n",
"                mask[index] = random.choices([0, 1], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
"        else:\n",
"            mask = np.ones(self.max_available_actions, dtype=np.int8)\n",
"\n",
"        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])\n",
"\n",
"        if front_tile is not None:\n",
"            # Picking up a key / toggling a door in front of the agent is always allowed.\n",
"            if front_tile.type == \"key\":\n",
"                mask[Actions.pickup] = 1\n",
"            if front_tile.type == \"door\":\n",
"                mask[Actions.toggle] = 1\n",
"\n",
"        return mask\n",
"\n",
"    def reset(self, *, seed=None, options=None):\n",
"        \"\"\"Reset the wrapped environment and return the masked observation and infos.\"\"\"\n",
"        obs, infos = self.env.reset(seed=seed, options=options)\n",
"\n",
"        if self.create_shield_at_reset and self.mask_actions:\n",
"            self.shield = self.shield_creator.create_shield(env=self.env)\n",
"\n",
"        self.keys = extract_keys(self.env)\n",
"        mask = self.create_action_mask()\n",
"        return {\n",
"            \"data\": obs[\"image\"],\n",
"            \"action_mask\": mask\n",
"        }, infos\n",
"\n",
"    def step(self, action):\n",
"        \"\"\"Step the wrapped environment and attach the action mask for the new state.\"\"\"\n",
"        orig_obs, rew, done, truncated, info = self.env.step(action)\n",
"\n",
"        mask = self.create_action_mask()\n",
"        obs = {\n",
"            \"data\": orig_obs[\"image\"],\n",
"            \"action_mask\": mask,\n",
"        }\n",
"\n",
"        return obs, rew, done, truncated, info\n",
|
|
"\n",
|
|
"\n",
|
"# Wrapper to use with a stable baseline algorithm\n",
"class MiniGridSbShieldingWrapper(gym.core.Wrapper):\n",
"    \"\"\"Exposes the \"image\" observation and offers the shield's action mask separately.\n",
"\n",
"    `reset` and `step` return only the image; the mask is not part of the\n",
"    observation and is obtained by calling `create_action_mask()`.\n",
"\n",
"    Args:\n",
"        env: The environment to wrap.\n",
"        shield_creator: Handler whose `create_shield(env=...)` returns the shield.\n",
"        shield_query_creator: Callable that maps the wrapped environment to the\n",
"            key under which the current state is looked up in the shield.\n",
"        create_shield_at_reset: Re-create the shield on every reset. If False,\n",
"            the shield is created on the first reset only.\n",
"        mask_actions: If False, the mask always allows every action.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 env,\n",
"                 shield_creator : ShieldHandler,\n",
"                 shield_query_creator,\n",
"                 create_shield_at_reset = True,\n",
"                 mask_actions=True,\n",
"                 ):\n",
"        super(MiniGridSbShieldingWrapper, self).__init__(env)\n",
"        self.max_available_actions = env.action_space.n\n",
"        self.observation_space = env.observation_space.spaces[\"image\"]\n",
"\n",
"        self.shield_creator = shield_creator\n",
"        # Previously accepted but ignored; now stored and honoured in reset().\n",
"        self.create_shield_at_reset = create_shield_at_reset\n",
"        self.mask_actions = mask_actions\n",
"        self.shield_query_creator = shield_query_creator\n",
"        # Created on the first reset(); defined here so the attribute always exists.\n",
"        self.shield = None\n",
"\n",
"    def create_action_mask(self):\n",
"        \"\"\"Return the int8 mask of the currently allowed actions (1 = allowed).\n",
"\n",
"        Raises:\n",
"            RuntimeError: If masking is enabled and `reset()` has not been called yet.\n",
"            AssertionError: If a shield choice cannot be mapped to an action index.\n",
"        \"\"\"\n",
"        if not self.mask_actions:\n",
"            return np.ones(self.max_available_actions, dtype=np.int8)\n",
"\n",
"        if self.shield is None:\n",
"            raise RuntimeError(\"reset() must be called before create_action_mask()\")\n",
"\n",
"        cur_pos_str = self.shield_query_creator(self.env)\n",
"\n",
"        # If the shield has choices for this state, only those can be enabled;\n",
"        # otherwise every action is valid.\n",
"        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
"            mask = np.zeros(self.max_available_actions, dtype=np.int8)\n",
"            for allowed_action in self.shield[cur_pos_str]:\n",
"                index = get_action_index_mapping(allowed_action.labels)\n",
"                if index is None:\n",
"                    # Raised explicitly: under `python -O` a bare assert is skipped\n",
"                    # and `mask[None] = ...` would then overwrite every entry.\n",
"                    raise AssertionError(f\"No action index for labels {allowed_action.labels}\")\n",
"\n",
"                # Enable the action with the probability attached to the choice.\n",
"                mask[index] = random.choices([0, 1], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
"        else:\n",
"            mask = np.ones(self.max_available_actions, dtype=np.int8)\n",
"\n",
"        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])\n",
"\n",
"        # Toggling a door in front of the agent is always allowed.\n",
"        if front_tile is not None and front_tile.type == \"door\":\n",
"            mask[Actions.toggle] = 1\n",
"\n",
"        return mask\n",
"\n",
"    def reset(self, *, seed=None, options=None):\n",
"        \"\"\"Reset the wrapped environment, (re)create the shield and return the image and infos.\"\"\"\n",
"        obs, infos = self.env.reset(seed=seed, options=options)\n",
"\n",
"        self.keys = extract_keys(self.env)\n",
"        # A shield is always needed once; later resets rebuild it only on request.\n",
"        if self.create_shield_at_reset or self.shield is None:\n",
"            self.shield = self.shield_creator.create_shield(env=self.env)\n",
"\n",
"        return obs[\"image\"], infos\n",
"\n",
"    def step(self, action):\n",
"        \"\"\"Step the wrapped environment and return only the image observation.\"\"\"\n",
"        orig_obs, rew, done, truncated, info = self.env.step(action)\n",
"        obs = orig_obs[\"image\"]\n",
"\n",
"        return obs, rew, done, truncated, info\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"If we want to use RLlib algorithms, we additionally need a model which performs the action masking."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC\n",
|
|
"from ray.rllib.models.torch.torch_modelv2 import TorchModelV2\n",
|
|
"from ray.rllib.utils.framework import try_import_torch\n",
|
|
"from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX\n",
|
|
"\n",
|
|
"torch, nn = try_import_torch()\n",
|
|
"\n",
|
"class TorchActionMaskModel(TorchModelV2, nn.Module):\n",
"    \"\"\"Fully connected model that adds the observation's action mask to its logits.\"\"\"\n",
"\n",
"    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):\n",
"        \"\"\"Set up both base classes and the internal network for the \"data\" sub-space.\"\"\"\n",
"        # Use `original_space` when the given space has one, else the space itself.\n",
"        orig_space = getattr(obs_space, \"original_space\", obs_space)\n",
"\n",
"        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs)\n",
"        nn.Module.__init__(self)\n",
"\n",
"        self.count = 0\n",
"\n",
"        self.internal_model = TorchFC(orig_space[\"data\"], action_space, num_outputs, model_config, name + \"_internal\")\n",
"\n",
"    def forward(self, input_dict, state, seq_lens):\n",
"        \"\"\"Return the masked logits and the unchanged `state`.\"\"\"\n",
"        observation = input_dict[\"obs\"]\n",
"\n",
"        # Compute the unmasked logits.\n",
"        logits, _ = self.internal_model({\"obs\": observation[\"data\"]})\n",
"\n",
"        # log(1) = 0 leaves a logit as it is; log(0) = -inf is clamped to FLOAT_MIN.\n",
"        inf_mask = torch.clamp(torch.log(observation[\"action_mask\"]), min=FLOAT_MIN)\n",
"\n",
"        return logits + inf_mask, state\n",
"\n",
"    def value_function(self):\n",
"        \"\"\"Return the value function output of the internal network.\"\"\"\n",
"        return self.internal_model.value_function()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Using these components, we can now train an RL agent with shielding."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import gymnasium as gym\n",
|
|
"import minigrid\n",
|
|
"\n",
|
|
"from ray import tune, air\n",
|
|
"from ray.tune import register_env\n",
|
|
"from ray.rllib.algorithms.ppo import PPOConfig\n",
|
|
"from ray.tune.logger import pretty_print\n",
|
|
"from ray.rllib.models import ModelCatalog\n",
|
|
"\n",
|
|
"\n",
|
"def shielding_env_creater(config):\n",
"    \"\"\"Create the shielded, one-hot encoded MiniGrid environment described by `config`.\n",
"\n",
"    Reads the optional keys \"name\" and \"framestack\" from `config` and the\n",
"    optional attribute `vector_index` (0 when absent).\n",
"    \"\"\"\n",
"    env_name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n",
"    num_frames = config.get(\"framestack\", 4)\n",
"    vector_index = getattr(config, \"vector_index\", 0)\n",
"\n",
"    handler = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
"\n",
"    shielded_env = MiniGridShieldingWrapper(\n",
"        gym.make(env_name),\n",
"        shield_creator=handler,\n",
"        shield_query_creator=create_shield_query,\n",
"        mask_actions=True,\n",
"    )\n",
"\n",
"    return OneHotShieldingWrapper(shielded_env, vector_index, framestack=num_frames)\n",
|
|
"\n",
|
|
"\n",
|
"def register_minigrid_shielding_env():\n",
"    \"\"\"Register the shielded environment and the action-masking model under their names.\"\"\"\n",
"    register_env(\"mini-grid-shielding\", shielding_env_creater)\n",
"    ModelCatalog.register_custom_model(\"shielding_model\", TorchActionMaskModel)\n",
|
|
"\n",
|
"register_minigrid_shielding_env()\n",
"\n",
"\n",
"# PPO on the registered shielded environment, using the registered custom model.\n",
"config = (\n",
"    PPOConfig()\n",
"    .rollouts(num_rollout_workers=1)\n",
"    .resources(num_gpus=0)\n",
"    .environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n",
"    .framework(\"torch\")\n",
"    .rl_module(_enable_rl_module_api=False)\n",
"    .training(_enable_learner_api=False, model={\"custom_model\": \"shielding_model\"})\n",
")\n",
"\n",
"tuner = tune.Tuner(\n",
"    \"PPO\",\n",
"    tune_config=tune.TuneConfig(metric=\"episode_reward_mean\", mode=\"max\", num_samples=1),\n",
"    run_config=air.RunConfig(\n",
"        stop={\n",
"            \"episode_reward_mean\": 94,\n",
"            \"timesteps_total\": 12000,\n",
"            \"training_iteration\": 12,\n",
"        },\n",
"        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),\n",
"    ),\n",
"    param_space=config,\n",
")\n",
"\n",
"results = tuner.fit()\n",
"best_result = results.get_best_result()\n",
"\n",
"import pprint\n",
"\n",
"# Only these entries of the best result's metrics are printed.\n",
"metrics_to_print = [\n",
"    \"episode_reward_mean\",\n",
"    \"episode_reward_max\",\n",
"    \"episode_reward_min\",\n",
"    \"episode_len_mean\",\n",
"]\n",
"pprint.pprint({name: value for name, value in best_result.metrics.items() if name in metrics_to_print})\n",
|
|
"\n",
|
|
" "
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "env",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|