You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
78 lines
2.1 KiB
78 lines
2.1 KiB
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Example: combining shielding with the Stable-Baselines3 Contrib (`sb3_contrib`) MaskablePPO algorithm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sb3_contrib import MaskablePPO\n",
|
|
"from sb3_contrib.common.maskable.evaluation import evaluate_policy\n",
|
|
"from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy\n",
|
|
"from sb3_contrib.common.wrappers import ActionMasker\n",
|
|
"from stable_baselines3.common.callbacks import BaseCallback\n",
|
|
"\n",
|
|
"import gymnasium as gym\n",
|
|
"\n",
|
|
"from shieldhandlers import MiniGridShieldHandler, create_shield_query\n",
|
|
"from wrappers import MiniGridSbShieldingWrapper"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def mask_fn(env: gym.Env):\n",
|
|
" return env.create_action_mask()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"shield_creator = MiniGridShieldHandler(\"grid.txt\", \"./main\", \"grid.prism\", \"Pmax=? [G !\\\"AgentIsInLavaAndNotDone\\\"]\")\n",
|
|
"\n",
|
|
"env = gym.make(\"MiniGrid-LavaCrossingS9N1-v0\", render_mode=\"rgb_array\")\n",
|
|
"env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)\n",
|
|
"env = ActionMasker(env, mask_fn)\n",
|
|
"model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1)\n",
|
|
"\n",
|
|
"\n",
|
|
"model.learn(10_000)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "env",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.12"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|