@ -45,6 +45,12 @@ 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "import os\n",    "import os\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "class Action():\n", 
		
	
		
			
				    "    def __init__(self, idx, prob=1, labels=[]) -> None:\n", 
		
	
		
			
				    "        self.idx = idx\n", 
		
	
		
			
				    "        self.prob = prob\n", 
		
	
		
			
				    "        self.labels = labels\n", 
		
	
		
			
				    "\n", 
		
	
		
			
				    "class ShieldHandler(ABC):\n",    "class ShieldHandler(ABC):\n", 
		
	
		
			
				    "    def __init__(self) -> None:\n",    "    def __init__(self) -> None:\n", 
		
	
		
			
				    "        pass\n",    "        pass\n", 
		
	
	
		
			
				
					
						
							 
					
					
						
							 
					
					
				 
				@ -98,11 +104,10 @@ 
		
	
		
			
				    "            choices = choice.choice_map\n",    "            choices = choice.choice_map\n", 
		
	
		
			
				    "            state_valuation = model.state_valuations.get_string(stateID)\n",    "            state_valuation = model.state_valuations.get_string(stateID)\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "            actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n", 
		
	
		
			
				    "            actions_to_be_executed = [Action (idx=  choice[1],  prob=choice[0] , labels= model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "            action_dictionary[state_valuation] = actions_to_be_executed\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "        stormpy.shields.export_shield(model, shield, \"Grid.shield\") \n", 
		
	
		
			
				    "            action_dictionary[state_valuation] = actions_to_be_executed \n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "        return action_dictionary\n",    "        return action_dictionary\n", 
		
	
		
			
				    "    \n",    "    \n", 
		
	
	
		
			
				
					
					
					
						
							 
					
				 
				@ -118,12 +123,7 @@ 
		
	
		
			
				    "    coordinates = env.env.agent_pos\n",    "    coordinates = env.env.agent_pos\n", 
		
	
		
			
				    "    view_direction = env.env.agent_dir\n",    "    view_direction = env.env.agent_dir\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "    key_text = \"\"\n", 
		
	
		
			
				    "\n", 
		
	
		
			
				    "    # only support one key for now\n", 
		
	
		
			
				    "   \n", 
		
	
		
			
				    "    #print(F\"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} \")\n", 
		
	
		
			
				    "    cur_pos_str = f\"[{key_text}!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n", 
		
	
		
			
				    "    cur_pos_str = f\"[!AgentDone\\t& xAgent={coordinates[0]}\\t& yAgent={coordinates[1]}\\t& viewAgent={view_direction}]\"\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "    return cur_pos_str\n",    "    return cur_pos_str\n", 
		
	
		
			
				    "    "    "    " 
		
	
	
		
			
				
					
					
					
						
							 
					
				 
				@ -145,6 +145,7 @@ 
		
	
		
			
				   "source": [   "source": [ 
		
	
		
			
				    "import gymnasium as gym\n",    "import gymnasium as gym\n", 
		
	
		
			
				    "import numpy as np\n",    "import numpy as np\n", 
		
	
		
			
				    "import random\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "from minigrid.core.actions import Actions\n",    "from minigrid.core.actions import Actions\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
	
		
			
				
					
					
					
						
							 
					
				 
				@ -158,9 +159,9 @@ 
		
	
		
			
				    "    def __init__(self, env, vector_index, framestack):\n",    "    def __init__(self, env, vector_index, framestack):\n", 
		
	
		
			
				    "        super().__init__(env)\n",    "        super().__init__(env)\n", 
		
	
		
			
				    "        self.framestack = framestack\n",    "        self.framestack = framestack\n", 
		
	
		
			
				    "        # 49=7x7 field of vision; 11 =object types; 6=colors; 3=state types.\n", 
		
	
		
			
				    "        # 49=7x7 field of vision; 16 =object types; 6=colors; 3=state types.\n", 
		
	
		
			
				    "        # +4: Direction.\n",    "        # +4: Direction.\n", 
		
	
		
			
				    "        self.single_frame_dim = 49 * (11  + 6 + 3) + 4\n", 
		
	
		
			
				    "        self.single_frame_dim = 49 * (16  + 6 + 3) + 4\n", 
		
	
		
			
				    "        self.init_x = None\n",    "        self.init_x = None\n", 
		
	
		
			
				    "        self.init_y = None\n",    "        self.init_y = None\n", 
		
	
		
			
				    "        self.x_positions = []\n",    "        self.x_positions = []\n", 
		
	
	
		
			
				
					
						
							 
					
					
						
							 
					
					
				 
				@ -210,7 +211,7 @@ 
		
	
		
			
				    "        image = obs[\"data\"]\n",    "        image = obs[\"data\"]\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
		
			
				    "        # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.\n",    "        # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.\n", 
		
	
		
			
				    "        objects = one_hot(image[:, :, 0], depth=11 )\n", 
		
	
		
			
				    "        objects = one_hot(image[:, :, 0], depth=16 )\n", 
		
	
		
			
				    "        colors = one_hot(image[:, :, 1], depth=6)\n",    "        colors = one_hot(image[:, :, 1], depth=6)\n", 
		
	
		
			
				    "        states = one_hot(image[:, :, 2], depth=3)\n",    "        states = one_hot(image[:, :, 2], depth=3)\n", 
		
	
		
			
				    "\n",    "\n", 
		
	
	
		
			
				
					
						
							 
					
					
						
							 
					
					
				 
				@ -260,10 +261,13 @@ 
		
	
		
			
				    "        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",    "        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n", 
		
	
		
			
				    "            allowed_actions = self.shield[cur_pos_str]\n",    "            allowed_actions = self.shield[cur_pos_str]\n", 
		
	
		
			
				    "            for allowed_action in allowed_actions:\n",    "            for allowed_action in allowed_actions:\n", 
		
	
		
			
				    "                 index =  get_action_index_mapping(allowed_action[1] ) # Allowed_action is a set\n", 
		
	
		
			
				    "                index =  get_action_index_mapping(allowed_action.labels ) # Allowed_action is a set\n", 
		
	
		
			
				    "                if index is None:\n",    "                if index is None:\n", 
		
	
		
			
				    "                    assert(False)\n",    "                    assert(False)\n", 
		
	
		
			
				    "                 mask[index] = 1.0\n", 
		
	
		
			
				    "                \n", 
		
	
		
			
				    "                allowed =  random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n", 
		
	
		
			
				    "                mask[index] = allowed               \n", 
		
	
		
			
				    "                     \n", 
		
	
		
			
				    "        else:\n",    "        else:\n", 
		
	
		
			
				    "            for index, x in enumerate(mask):\n",    "            for index, x in enumerate(mask):\n", 
		
	
		
			
				    "                mask[index] = 1.0\n",    "                mask[index] = 1.0\n", 
		
	
	
		
			
				
					
						
							 
					
					
						
							 
					
					
				 
				@ -336,11 +340,12 @@ 
		
	
		
			
				    "        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n",    "        if cur_pos_str in self.shield and self.shield[cur_pos_str]:\n", 
		
	
		
			
				    "            allowed_actions = self.shield[cur_pos_str]\n",    "            allowed_actions = self.shield[cur_pos_str]\n", 
		
	
		
			
				    "            for allowed_action in allowed_actions:\n",    "            for allowed_action in allowed_actions:\n", 
		
	
		
			
				    "                 index =  get_action_index_mapping(allowed_action[1] )\n", 
		
	
		
			
				    "                index =  get_action_index_mapping(allowed_action.labels )\n", 
		
	
		
			
				    "                if index is None:\n",    "                if index is None:\n", 
		
	
		
			
				    "                     assert(False)\n",    "                     assert(False)\n", 
		
	
		
			
				    "                \n",    "                \n", 
		
	
		
			
				    "                 mask[index] = 1.0\n", 
		
	
		
			
				    "                \n", 
		
	
		
			
				    "                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]\n", 
		
	
		
			
				    "        else:\n",    "        else:\n", 
		
	
		
			
				    "            for index, x in enumerate(mask):\n",    "            for index, x in enumerate(mask):\n", 
		
	
		
			
				    "                mask[index] = 1.0\n",    "                mask[index] = 1.0\n",