From c0fc671870c926a8c9e535223d19ae222c23a6e9 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Fri, 15 Sep 2023 12:26:46 +0200 Subject: [PATCH] probability evaluation in jupyter notebooks --- examples/shields/rl/dqn_rllib.ipynb | 2 +- examples/shields/rl/ppo_rllib.ipynb | 2 +- examples/shields/rl/tutorial.ipynb | 53 ++++++++++++++++------------- 3 files changed, 31 insertions(+), 26 deletions(-) diff --git a/examples/shields/rl/dqn_rllib.ipynb b/examples/shields/rl/dqn_rllib.ipynb index 20dc44e..bf5152b 100644 --- a/examples/shields/rl/dqn_rllib.ipynb +++ b/examples/shields/rl/dqn_rllib.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/examples/shields/rl/ppo_rllib.ipynb b/examples/shields/rl/ppo_rllib.ipynb index 0ff4d10..9b6676a 100644 --- a/examples/shields/rl/ppo_rllib.ipynb +++ b/examples/shields/rl/ppo_rllib.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ diff --git a/examples/shields/rl/tutorial.ipynb b/examples/shields/rl/tutorial.ipynb index 039d6bb..f82bac2 100644 --- a/examples/shields/rl/tutorial.ipynb +++ b/examples/shields/rl/tutorial.ipynb @@ -45,6 +45,12 @@ "\n", "import os\n", "\n", + "class Action():\n", + " def __init__(self, idx, prob=1, labels=[]) -> None:\n", + " self.idx = idx\n", + " self.prob = prob\n", + " self.labels = labels\n", + "\n", "class ShieldHandler(ABC):\n", " def __init__(self) -> None:\n", " pass\n", @@ -98,12 +104,11 @@ " choices = choice.choice_map\n", " state_valuation = model.state_valuations.get_string(stateID)\n", "\n", - " actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n", + " actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n", + "\n", "\n", " action_dictionary[state_valuation] = actions_to_be_executed\n", "\n", - " stormpy.shields.export_shield(model, shield, \"Grid.shield\")\n", - " \n", " return action_dictionary\n", " \n", " \n", @@ -118,12 +123,7 @@ " coordinates = env.env.agent_pos\n", " view_direction = env.env.agent_dir\n", "\n", - " key_text = \"\"\n", - "\n", - " # only support one key for now\n", - " \n", - " #print(F\"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} \")\n", - " cur_pos_str = f\"[{key_text}!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n", + " cur_pos_str = f\"[!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n", "\n", " return cur_pos_str\n", " " @@ -145,6 +145,7 @@ "source": [ "import gymnasium as gym\n", "import numpy as np\n", + "import random\n", "\n", "from minigrid.core.actions import Actions\n", "\n", @@ -158,9 +159,9 @@ " def __init__(self, env, vector_index, framestack):\n", " super().__init__(env)\n", " self.framestack = framestack\n", - " # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.\n", + " # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.\n", " # +4: Direction.\n", - " self.single_frame_dim = 49 * (11 + 6 + 3) + 4\n", + " self.single_frame_dim = 49 * (16 + 6 + 3) + 4\n", " self.init_x = None\n", " self.init_y = None\n", " self.x_positions = []\n", @@ -210,7 +211,7 @@ " image = obs[\"data\"]\n", "\n", " # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.\n", - " objects = one_hot(image[:, :, 0], depth=11)\n", + " objects = one_hot(image[:, :, 0], depth=16)\n", " colors = one_hot(image[:, :, 1], depth=6)\n", " states = one_hot(image[:, :, 2], depth=3)\n", "\n", @@ -258,12 +259,15 @@ " mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n", "\n", " if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n", - " allowed_actions = self.shield[cur_pos_str]\n", - " for allowed_action in allowed_actions:\n", - " index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set\n", - " if index is None:\n", - " assert(False)\n", - " mask[index] = 1.0\n", + " allowed_actions = self.shield[cur_pos_str]\n", + " for allowed_action in allowed_actions:\n", + " index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set\n", + " if index is None:\n", + " assert(False)\n", + " \n", + " allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n", + " mask[index] = allowed \n", + " \n", " else:\n", " for index, x in enumerate(mask):\n", " mask[index] = 1.0\n", @@ -334,13 +338,14 @@ " mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n", "\n", " if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n", - " allowed_actions = self.shield[cur_pos_str]\n", - " for allowed_action in allowed_actions:\n", - " index = get_action_index_mapping(allowed_action[1])\n", - " if index is None:\n", + " allowed_actions = self.shield[cur_pos_str]\n", + " for allowed_action in allowed_actions:\n", + " index = get_action_index_mapping(allowed_action.labels)\n", + " if index is None:\n", " assert(False)\n", - " \n", - " mask[index] = 1.0\n", + " \n", + " \n", + " mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n", " else:\n", " for index, x in enumerate(mask):\n", " mask[index] = 1.0\n",