You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

557 lines
22 KiB

2 months ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "The requisites for applying a shield while training a RL Agent in the Minigrid Environment with PPO Algorithm are:\n",
  8. "\n",
  9. "# Binaries\n",
  10. "- Tempest\n",
  11. "- Minigrid2Prism\n",
  12. "\n",
  13. "\n",
  14. "# Python packages:\n",
  15. "- Tempestpy\n",
  16. "- Minigrid with the printGrid Function\n",
  17. "- ray / rllib"
  18. ]
  19. },
  20. {
  21. "cell_type": "markdown",
  22. "metadata": {},
  23. "source": [
  24. "The shield handler is responsible for creating and querying the shield."
  25. ]
  26. },
  27. {
  28. "cell_type": "code",
  29. "execution_count": null,
  30. "metadata": {},
  31. "outputs": [],
  32. "source": [
  33. "\n",
  34. "import stormpy\n",
  35. "import stormpy.core\n",
  36. "import stormpy.simulator\n",
  37. "\n",
  38. "import stormpy.shields\n",
  39. "import stormpy.logic\n",
  40. "\n",
  41. "import stormpy.examples\n",
  42. "import stormpy.examples.files\n",
  43. "\n",
  44. "from abc import ABC\n",
  45. "\n",
  46. "import os\n",
  47. "\n",
  48. "class Action():\n",
  49. " def __init__(self, idx, prob=1, labels=[]) -> None:\n",
  50. " self.idx = idx\n",
  51. " self.prob = prob\n",
  52. " self.labels = labels\n",
  53. "\n",
  54. "class ShieldHandler(ABC):\n",
  55. " def __init__(self) -> None:\n",
  56. " pass\n",
  57. " def create_shield(self, **kwargs) -> dict:\n",
  58. " pass\n",
  59. "\n",
  60. "class MiniGridShieldHandler(ShieldHandler):\n",
  61. " def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None:\n",
  62. " self.grid_file = grid_file\n",
  63. " self.grid_to_prism_path = grid_to_prism_path\n",
  64. " self.prism_path = prism_path\n",
  65. " self.formula = formula\n",
  66. " \n",
  67. " def __export_grid_to_text(self, env):\n",
  68. " f = open(self.grid_file, \"w\")\n",
  69. " f.write(env.printGrid(init=True))\n",
  70. " f.close()\n",
  71. "\n",
  72. " \n",
  73. " def __create_prism(self):\n",
  74. " result = os.system(F\"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}\")\n",
  75. " \n",
  76. " assert result == 0, \"Prism file could not be generated\"\n",
  77. " \n",
  78. " f = open(self.prism_path, \"a\")\n",
  79. " f.write(\"label \\\"AgentIsInLava\\\" = AgentIsInLava;\")\n",
  80. " f.close()\n",
  81. " \n",
  82. " def __create_shield_dict(self):\n",
  83. " program = stormpy.parse_prism_program(self.prism_path)\n",
  84. " shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) \n",
  85. " \n",
  86. " formulas = stormpy.parse_properties_for_prism_program(self.formula, program)\n",
  87. " options = stormpy.BuilderOptions([p.raw_formula for p in formulas])\n",
  88. " options.set_build_state_valuations(True)\n",
  89. " options.set_build_choice_labels(True)\n",
  90. " options.set_build_all_labels()\n",
  91. " model = stormpy.build_sparse_model_with_options(program, options)\n",
  92. " \n",
  93. " result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)\n",
  94. " \n",
  95. " assert result.has_scheduler\n",
  96. " assert result.has_shield\n",
  97. " shield = result.shield\n",
  98. " \n",
  99. " action_dictionary = {}\n",
  100. " shield_scheduler = shield.construct()\n",
  101. " \n",
  102. " for stateID in model.states:\n",
  103. " choice = shield_scheduler.get_choice(stateID)\n",
  104. " choices = choice.choice_map\n",
  105. " state_valuation = model.state_valuations.get_string(stateID)\n",
  106. "\n",
  107. " actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n",
  108. "\n",
  109. "\n",
  110. " action_dictionary[state_valuation] = actions_to_be_executed\n",
  111. "\n",
  112. " return action_dictionary\n",
  113. " \n",
  114. " \n",
  115. " def create_shield(self, **kwargs):\n",
  116. " env = kwargs[\"env\"]\n",
  117. " self.__export_grid_to_text(env)\n",
  118. " self.__create_prism()\n",
  119. " \n",
  120. " return self.__create_shield_dict()\n",
  121. " \n",
  122. "def create_shield_query(env):\n",
  123. " coordinates = env.env.agent_pos\n",
  124. " view_direction = env.env.agent_dir\n",
  125. "\n",
  126. " cur_pos_str = f\"[!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n",
  127. "\n",
  128. " return cur_pos_str\n",
  129. " "
  130. ]
  131. },
  132. {
  133. "cell_type": "markdown",
  134. "metadata": {},
  135. "source": [
  136. "To train a learning algorithm with shielding the allowed actions need to be embedded in the observation. \n",
  137. "This can be done by implementing a gym wrapper handling the action embedding for the enviornment."
  138. ]
  139. },
  140. {
  141. "cell_type": "code",
  142. "execution_count": null,
  143. "metadata": {},
  144. "outputs": [],
  145. "source": [
  146. "import gymnasium as gym\n",
  147. "import numpy as np\n",
  148. "import random\n",
  149. "\n",
  150. "from minigrid.core.actions import Actions\n",
  151. "\n",
  152. "from gymnasium.spaces import Dict, Box\n",
  153. "from collections import deque\n",
  154. "from ray.rllib.utils.numpy import one_hot\n",
  155. "\n",
  156. "from helpers import get_action_index_mapping, extract_keys\n",
  157. "\n",
  158. "class OneHotShieldingWrapper(gym.core.ObservationWrapper):\n",
  159. " def __init__(self, env, vector_index, framestack):\n",
  160. " super().__init__(env)\n",
  161. " self.framestack = framestack\n",
  162. " # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.\n",
  163. " # +4: Direction.\n",
  164. " self.single_frame_dim = 49 * (16 + 6 + 3) + 4\n",
  165. " self.init_x = None\n",
  166. " self.init_y = None\n",
  167. " self.x_positions = []\n",
  168. " self.y_positions = []\n",
  169. " self.x_y_delta_buffer = deque(maxlen=100)\n",
  170. " self.vector_index = vector_index\n",
  171. " self.frame_buffer = deque(maxlen=self.framestack)\n",
  172. " for _ in range(self.framestack):\n",
  173. " self.frame_buffer.append(np.zeros((self.single_frame_dim,)))\n",
  174. "\n",
  175. " self.observation_space = Dict(\n",
  176. " {\n",
  177. " \"data\": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),\n",
  178. " \"action_mask\": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),\n",
  179. " }\n",
  180. " )\n",
  181. "\n",
  182. " def observation(self, obs):\n",
  183. " # Debug output: max-x/y positions to watch exploration progress.\n",
  184. " # print(F\"Initial observation in Wrapper {obs}\")\n",
  185. " if self.step_count == 0:\n",
  186. " for _ in range(self.framestack):\n",
  187. " self.frame_buffer.append(np.zeros((self.single_frame_dim,)))\n",
  188. " if self.vector_index == 0:\n",
  189. " if self.x_positions:\n",
  190. " max_diff = max(\n",
  191. " np.sqrt(\n",
  192. " (np.array(self.x_positions) - self.init_x) ** 2\n",
  193. " + (np.array(self.y_positions) - self.init_y) ** 2\n",
  194. " )\n",
  195. " )\n",
  196. " self.x_y_delta_buffer.append(max_diff)\n",
  197. " print(\n",
  198. " \"100-average dist travelled={}\".format(\n",
  199. " np.mean(self.x_y_delta_buffer)\n",
  200. " )\n",
  201. " )\n",
  202. " self.x_positions = []\n",
  203. " self.y_positions = []\n",
  204. " self.init_x = self.agent_pos[0]\n",
  205. " self.init_y = self.agent_pos[1]\n",
  206. "\n",
  207. "\n",
  208. " self.x_positions.append(self.agent_pos[0])\n",
  209. " self.y_positions.append(self.agent_pos[1])\n",
  210. "\n",
  211. " image = obs[\"data\"]\n",
  212. "\n",
  213. " # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.\n",
  214. " objects = one_hot(image[:, :, 0], depth=16)\n",
  215. " colors = one_hot(image[:, :, 1], depth=6)\n",
  216. " states = one_hot(image[:, :, 2], depth=3)\n",
  217. "\n",
  218. " all_ = np.concatenate([objects, colors, states], -1)\n",
  219. " all_flat = np.reshape(all_, (-1,))\n",
  220. " direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)\n",
  221. " single_frame = np.concatenate([all_flat, direction])\n",
  222. " self.frame_buffer.append(single_frame)\n",
  223. "\n",
  224. " tmp = {\"data\": np.concatenate(self.frame_buffer), \"action_mask\": obs[\"action_mask\"] }\n",
  225. " return tmp\n",
  226. "\n",
  227. "# Environment wrapper handling action embedding in observations\n",
  228. "class MiniGridShieldingWrapper(gym.core.Wrapper):\n",
  229. " def __init__(self, \n",
  230. " env, \n",
  231. " shield_creator : ShieldHandler,\n",
  232. " shield_query_creator,\n",
  233. " create_shield_at_reset=True, \n",
  234. " mask_actions=True):\n",
  235. " super(MiniGridShieldingWrapper, self).__init__(env)\n",
  236. " self.max_available_actions = env.action_space.n\n",
  237. " self.observation_space = Dict(\n",
  238. " {\n",
  239. " \"data\": env.observation_space.spaces[\"image\"],\n",
  240. " \"action_mask\" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),\n",
  241. " }\n",
  242. " )\n",
  243. " self.shield_creator = shield_creator\n",
  244. " self.create_shield_at_reset = create_shield_at_reset\n",
  245. " self.shield = shield_creator.create_shield(env=self.env)\n",
  246. " self.mask_actions = mask_actions\n",
  247. " self.shield_query_creator = shield_query_creator\n",
  248. "\n",
  249. " def create_action_mask(self):\n",
  250. " if not self.mask_actions:\n",
  251. " return np.array([1.0] * self.max_available_actions, dtype=np.int8)\n",
  252. " \n",
  253. " cur_pos_str = self.shield_query_creator(self.env)\n",
  254. " \n",
  255. " # Create the mask\n",
  256. " # If shield restricts action mask only valid with 1.0\n",
  257. " # else set all actions as valid\n",
  258. " allowed_actions = []\n",
  259. " mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n",
  260. "\n",
  261. " if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
  262. " allowed_actions = self.shield[cur_pos_str]\n",
  263. " for allowed_action in allowed_actions:\n",
  264. " index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set\n",
  265. " if index is None:\n",
  266. " assert(False)\n",
  267. " \n",
  268. " allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
  269. " mask[index] = allowed \n",
  270. " \n",
  271. " else:\n",
  272. " for index, x in enumerate(mask):\n",
  273. " mask[index] = 1.0\n",
  274. " \n",
  275. " front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])\n",
  276. "\n",
  277. " if front_tile is not None and front_tile.type == \"key\":\n",
  278. " mask[Actions.pickup] = 1.0\n",
  279. " \n",
  280. " if front_tile and front_tile.type == \"door\":\n",
  281. " mask[Actions.toggle] = 1.0\n",
  282. " \n",
  283. " return mask\n",
  284. "\n",
  285. " def reset(self, *, seed=None, options=None):\n",
  286. " obs, infos = self.env.reset(seed=seed, options=options)\n",
  287. " \n",
  288. " if self.create_shield_at_reset and self.mask_actions:\n",
  289. " self.shield = self.shield_creator.create_shield(env=self.env)\n",
  290. " \n",
  291. " self.keys = extract_keys(self.env)\n",
  292. " mask = self.create_action_mask()\n",
  293. " return {\n",
  294. " \"data\": obs[\"image\"],\n",
  295. " \"action_mask\": mask\n",
  296. " }, infos\n",
  297. "\n",
  298. " def step(self, action):\n",
  299. " orig_obs, rew, done, truncated, info = self.env.step(action)\n",
  300. "\n",
  301. " mask = self.create_action_mask()\n",
  302. " obs = {\n",
  303. " \"data\": orig_obs[\"image\"],\n",
  304. " \"action_mask\": mask,\n",
  305. " }\n",
  306. " \n",
  307. " return obs, rew, done, truncated, info\n",
  308. "\n",
  309. "\n",
  310. "# Wrapper to use with a stable baseline algorithm\n",
  311. "class MiniGridSbShieldingWrapper(gym.core.Wrapper):\n",
  312. " def __init__(self, \n",
  313. " env, \n",
  314. " shield_creator : ShieldHandler,\n",
  315. " shield_query_creator,\n",
  316. " create_shield_at_reset = True,\n",
  317. " mask_actions=True,\n",
  318. " ):\n",
  319. " super(MiniGridSbShieldingWrapper, self).__init__(env)\n",
  320. " self.max_available_actions = env.action_space.n\n",
  321. " self.observation_space = env.observation_space.spaces[\"image\"]\n",
  322. " \n",
  323. " self.shield_creator = shield_creator\n",
  324. " self.mask_actions = mask_actions\n",
  325. " self.shield_query_creator = shield_query_creator\n",
  326. "\n",
  327. " def create_action_mask(self):\n",
  328. " if not self.mask_actions:\n",
  329. " return np.array([1.0] * self.max_available_actions, dtype=np.int8)\n",
  330. " \n",
  331. " cur_pos_str = self.shield_query_creator(self.env)\n",
  332. " \n",
  333. " allowed_actions = []\n",
  334. "\n",
  335. " # Create the mask\n",
  336. " # If shield restricts actions, mask only valid actions with 1.0\n",
  337. " # else set all actions valid\n",
  338. " mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n",
  339. "\n",
  340. " if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
  341. " allowed_actions = self.shield[cur_pos_str]\n",
  342. " for allowed_action in allowed_actions:\n",
  343. " index = get_action_index_mapping(allowed_action.labels)\n",
  344. " if index is None:\n",
  345. " assert(False)\n",
  346. " \n",
  347. " \n",
  348. " mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
  349. " else:\n",
  350. " for index, x in enumerate(mask):\n",
  351. " mask[index] = 1.0\n",
  352. " \n",
  353. " front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])\n",
  354. "\n",
  355. " \n",
  356. " if front_tile and front_tile.type == \"door\":\n",
  357. " mask[Actions.toggle] = 1.0 \n",
  358. " \n",
  359. " return mask \n",
  360. " \n",
  361. "\n",
  362. " def reset(self, *, seed=None, options=None):\n",
  363. " obs, infos = self.env.reset(seed=seed, options=options)\n",
  364. " \n",
  365. " keys = extract_keys(self.env)\n",
  366. " shield = self.shield_creator.create_shield(env=self.env)\n",
  367. " \n",
  368. " self.keys = keys\n",
  369. " self.shield = shield\n",
  370. " return obs[\"image\"], infos\n",
  371. "\n",
  372. " def step(self, action):\n",
  373. " orig_obs, rew, done, truncated, info = self.env.step(action)\n",
  374. " obs = orig_obs[\"image\"]\n",
  375. " \n",
  376. " return obs, rew, done, truncated, info\n",
  377. "\n"
  378. ]
  379. },
  380. {
  381. "cell_type": "markdown",
  382. "metadata": {},
  383. "source": [
  384. "If we want to use rllib algorithms we additionaly need a model which performs the action masking."
  385. ]
  386. },
  387. {
  388. "cell_type": "code",
  389. "execution_count": null,
  390. "metadata": {},
  391. "outputs": [],
  392. "source": [
  393. "from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC\n",
  394. "from ray.rllib.models.torch.torch_modelv2 import TorchModelV2\n",
  395. "from ray.rllib.utils.framework import try_import_torch\n",
  396. "from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX\n",
  397. "\n",
  398. "torch, nn = try_import_torch()\n",
  399. "\n",
  400. "class TorchActionMaskModel(TorchModelV2, nn.Module):\n",
  401. "\n",
  402. " def __init__(\n",
  403. " self,\n",
  404. " obs_space,\n",
  405. " action_space,\n",
  406. " num_outputs,\n",
  407. " model_config,\n",
  408. " name,\n",
  409. " **kwargs,\n",
  410. " ):\n",
  411. " orig_space = getattr(obs_space, \"original_space\", obs_space)\n",
  412. " \n",
  413. " TorchModelV2.__init__(\n",
  414. " self, obs_space, action_space, num_outputs, model_config, name, **kwargs\n",
  415. " )\n",
  416. " nn.Module.__init__(self)\n",
  417. " \n",
  418. " self.count = 0\n",
  419. "\n",
  420. " self.internal_model = TorchFC(\n",
  421. " orig_space[\"data\"],\n",
  422. " action_space,\n",
  423. " num_outputs,\n",
  424. " model_config,\n",
  425. " name + \"_internal\",\n",
  426. " )\n",
  427. " \n",
  428. "\n",
  429. " def forward(self, input_dict, state, seq_lens):\n",
  430. " # Extract the available actions tensor from the observation.\n",
  431. " # Compute the unmasked logits.\n",
  432. " logits, _ = self.internal_model({\"obs\": input_dict[\"obs\"][\"data\"]})\n",
  433. " \n",
  434. " action_mask = input_dict[\"obs\"][\"action_mask\"]\n",
  435. "\n",
  436. " inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)\n",
  437. " masked_logits = logits + inf_mask\n",
  438. "\n",
  439. " # Return masked logits.\n",
  440. " return masked_logits, state\n",
  441. "\n",
  442. " def value_function(self):\n",
  443. " return self.internal_model.value_function()"
  444. ]
  445. },
  446. {
  447. "cell_type": "markdown",
  448. "metadata": {},
  449. "source": [
  450. "Using these components we can now train an rl agent with shielding."
  451. ]
  452. },
  453. {
  454. "cell_type": "code",
  455. "execution_count": null,
  456. "metadata": {},
  457. "outputs": [],
  458. "source": [
  459. "import gymnasium as gym\n",
  460. "import minigrid\n",
  461. "\n",
  462. "from ray import tune, air\n",
  463. "from ray.tune import register_env\n",
  464. "from ray.rllib.algorithms.ppo import PPOConfig\n",
  465. "from ray.tune.logger import pretty_print\n",
  466. "from ray.rllib.models import ModelCatalog\n",
  467. "\n",
  468. "\n",
  469. "def shielding_env_creater(config):\n",
  470. " name = config.get(\"name\", \"MiniGrid-LavaCrossingS9N1-v0\")\n",
  471. " framestack = config.get(\"framestack\", 4)\n",
  472. " \n",
  473. " shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
  474. " \n",
  475. " env = gym.make(name)\n",
  476. " env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=True)\n",
  477. " env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, \"vector_index\") else 0,\n",
  478. " framestack=framestack)\n",
  479. " \n",
  480. " return env\n",
  481. "\n",
  482. "\n",
  483. "def register_minigrid_shielding_env():\n",
  484. " env_name = \"mini-grid-shielding\"\n",
  485. " register_env(env_name, shielding_env_creater)\n",
  486. " ModelCatalog.register_custom_model(\n",
  487. " \"shielding_model\", \n",
  488. " TorchActionMaskModel)\n",
  489. "\n",
  490. "register_minigrid_shielding_env()\n",
  491. "\n",
  492. "\n",
  493. "config = (PPOConfig()\n",
  494. " .rollouts(num_rollout_workers=1)\n",
  495. " .resources(num_gpus=0)\n",
  496. " .environment(env=\"mini-grid-shielding\", env_config={\"name\": \"MiniGrid-LavaCrossingS9N1-v0\"})\n",
  497. " .framework(\"torch\")\n",
  498. " .rl_module(_enable_rl_module_api = False)\n",
  499. " .training(_enable_learner_api=False ,model={\n",
  500. " \"custom_model\": \"shielding_model\"\n",
  501. " }))\n",
  502. "\n",
  503. "tuner = tune.Tuner(\"PPO\",\n",
  504. " tune_config=tune.TuneConfig(\n",
  505. " metric=\"episode_reward_mean\",\n",
  506. " mode=\"max\",\n",
  507. " num_samples=1,\n",
  508. " \n",
  509. " ),\n",
  510. " run_config=air.RunConfig(\n",
  511. " stop = {\"episode_reward_mean\": 94,\n",
  512. " \"timesteps_total\": 12000,\n",
  513. " \"training_iteration\": 12}, \n",
  514. " checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ),\n",
  515. " ),\n",
  516. " param_space=config,)\n",
  517. "\n",
  518. "results = tuner.fit()\n",
  519. "best_result = results.get_best_result()\n",
  520. "\n",
  521. "import pprint\n",
  522. "\n",
  523. "metrics_to_print = [\n",
  524. "\"episode_reward_mean\",\n",
  525. "\"episode_reward_max\",\n",
  526. "\"episode_reward_min\",\n",
  527. "\"episode_len_mean\",\n",
  528. "]\n",
  529. "pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})\n",
  530. "\n",
  531. " "
  532. ]
  533. }
  534. ],
  535. "metadata": {
  536. "kernelspec": {
  537. "display_name": "env",
  538. "language": "python",
  539. "name": "python3"
  540. },
  541. "language_info": {
  542. "codemirror_mode": {
  543. "name": "ipython",
  544. "version": 3
  545. },
  546. "file_extension": ".py",
  547. "mimetype": "text/x-python",
  548. "name": "python",
  549. "nbconvert_exporter": "python",
  550. "pygments_lexer": "ipython3",
  551. "version": "3.10.12"
  552. },
  553. "orig_nbformat": 4
  554. },
  555. "nbformat": 4,
  556. "nbformat_minor": 2
  557. }