You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

78 lines
2.1 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "## Example: how to combine shielding with the Stable-Baselines3 Contrib MaskablePPO algorithm"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": null,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "from sb3_contrib import MaskablePPO\n",
  17. "from sb3_contrib.common.maskable.evaluation import evaluate_policy\n",
  18. "from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy\n",
  19. "from sb3_contrib.common.wrappers import ActionMasker\n",
  20. "from stable_baselines3.common.callbacks import BaseCallback\n",
  21. "\n",
  22. "import gymnasium as gym\n",
  23. "\n",
  24. "from shieldhandlers import MiniGridShieldHandler, create_shield_query\n",
  25. "from wrappers import MiniGridSbShieldingWrapper"
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": null,
  31. "metadata": {},
  32. "outputs": [],
  33. "source": [
  34. "def mask_fn(env: gym.Env):\n",
  35. " return env.create_action_mask()"
  36. ]
  37. },
  38. {
  39. "cell_type": "code",
  40. "execution_count": null,
  41. "metadata": {},
  42. "outputs": [],
  43. "source": [
  44. "shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
  45. "\n",
  46. "env = gym.make(\"MiniGrid-LavaCrossingS9N1-v0\", render_mode=\"rgb_array\")\n",
  47. "env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)\n",
  48. "env = ActionMasker(env, mask_fn)\n",
  49. "model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)\n",
  50. "\n",
  51. "\n",
  52. "model.learn(10_000)"
  53. ]
  54. }
  55. ],
  56. "metadata": {
  57. "kernelspec": {
  58. "display_name": "env",
  59. "language": "python",
  60. "name": "python3"
  61. },
  62. "language_info": {
  63. "codemirror_mode": {
  64. "name": "ipython",
  65. "version": 3
  66. },
  67. "file_extension": ".py",
  68. "mimetype": "text/x-python",
  69. "name": "python",
  70. "nbconvert_exporter": "python",
  71. "pygments_lexer": "ipython3",
  72. "version": "3.10.12"
  73. },
  74. "orig_nbformat": 4
  75. },
  76. "nbformat": 4,
  77. "nbformat_minor": 2
  78. }