@ -45,6 +45,12 @@
"\n",
"import os\n",
"\n",
"class Action():\n",
" def __init__(self, idx, prob=1, labels=[]) -> None:\n",
" self.idx = idx\n",
" self.prob = prob\n",
" self.labels = labels\n",
"\n",
"class ShieldHandler(ABC):\n",
" def __init__(self) -> None:\n",
" pass\n",
@ -98,12 +104,11 @@
" choices = choice.choice_map\n",
" state_valuation = model.state_valuations.get_string(stateID)\n",
"\n",
" actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n",
" actions_to_be_executed = [Action(idx= choice[1], prob=choice[0], labels=model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n",
"\n",
"\n",
" action_dictionary[state_valuation] = actions_to_be_executed\n",
"\n",
" stormpy.shields.export_shield(model, shield, \"Grid.shield\")\n",
" \n",
" return action_dictionary\n",
" \n",
" \n",
@ -118,12 +123,7 @@
" coordinates = env.env.agent_pos\n",
" view_direction = env.env.agent_dir\n",
"\n",
" key_text = \"\"\n",
"\n",
" # only support one key for now\n",
" \n",
" #print(F\"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} \")\n",
" cur_pos_str = f\"[{key_text}!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n",
" cur_pos_str = f\"[!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n",
"\n",
" return cur_pos_str\n",
" "
@ -145,6 +145,7 @@
"source": [
"import gymnasium as gym\n",
"import numpy as np\n",
"import random\n",
"\n",
"from minigrid.core.actions import Actions\n",
"\n",
@ -158,9 +159,9 @@
" def __init__(self, env, vector_index, framestack):\n",
" super().__init__(env)\n",
" self.framestack = framestack\n",
" # 49=7x7 field of vision; 11 =object types; 6=colors; 3=state types.\n",
" # 49=7x7 field of vision; 16 =object types; 6=colors; 3=state types.\n",
" # +4: Direction.\n",
" self.single_frame_dim = 49 * (11 + 6 + 3) + 4\n",
" self.single_frame_dim = 49 * (16 + 6 + 3) + 4\n",
" self.init_x = None\n",
" self.init_y = None\n",
" self.x_positions = []\n",
@ -210,7 +211,7 @@
" image = obs[\"data\"]\n",
"\n",
" # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.\n",
" objects = one_hot(image[:, :, 0], depth=11 )\n",
" objects = one_hot(image[:, :, 0], depth=16 )\n",
" colors = one_hot(image[:, :, 1], depth=6)\n",
" states = one_hot(image[:, :, 2], depth=3)\n",
"\n",
@ -258,12 +259,15 @@
" mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n",
"\n",
" if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
" allowed_actions = self.shield[cur_pos_str]\n",
" for allowed_action in allowed_actions:\n",
" index = get_action_index_mapping(allowed_action[1]) # Allowed_action is a set\n",
" if index is None:\n",
" assert(False)\n",
" mask[index] = 1.0\n",
" allowed_actions = self.shield[cur_pos_str]\n",
" for allowed_action in allowed_actions:\n",
" index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set\n",
" if index is None:\n",
" assert(False)\n",
" \n",
" allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
" mask[index] = allowed \n",
" \n",
" else:\n",
" for index, x in enumerate(mask):\n",
" mask[index] = 1.0\n",
@ -334,13 +338,14 @@
" mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)\n",
"\n",
" if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",
" allowed_actions = self.shield[cur_pos_str]\n",
" for allowed_action in allowed_actions:\n",
" index = get_action_index_mapping(allowed_action[1] )\n",
" if index is None:\n",
" allowed_actions = self.shield[cur_pos_str]\n",
" for allowed_action in allowed_actions:\n",
" index = get_action_index_mapping(allowed_action.labels )\n",
" if index is None:\n",
" assert(False)\n",
" \n",
" mask[index] = 1.0\n",
" \n",
" \n",
" mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n",
" else:\n",
" for index, x in enumerate(mask):\n",
" mask[index] = 1.0\n",