|  |  | @ -82,6 +82,7 @@ def parse_arguments(argparse): | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--grid_path", default="Grid.txt") | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--prism_path", default="Grid.PRISM") | 
			
		
	
		
			
				
					|  |  |  |     parser.add_argument("--no_masking", default=False) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     args = parser.parse_args() | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
	
		
			
				
					|  |  | @ -92,14 +93,14 @@ def env_creater_custom(config): | 
			
		
	
		
			
				
					|  |  |  |     # name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | 
			
		
	
		
			
				
					|  |  |  |     # # name = config.get("name", "MiniGrid-Empty-8x8-v0") | 
			
		
	
		
			
				
					|  |  |  |     framestack = config.get("framestack", 4) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     shield = config.get("shield", {}) | 
			
		
	
		
			
				
					|  |  |  |     # env = gym.make(name) | 
			
		
	
		
			
				
					|  |  |  |     # env = ParametricActionsMiniGridEnv(config) | 
			
		
	
		
			
				
					|  |  |  |     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | 
			
		
	
		
			
				
					|  |  |  |     framestack = config.get("framestack", 4) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     env = gym.make(name) | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridEnvWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridEnvWrapper(env, shield=shield) | 
			
		
	
		
			
				
					|  |  |  |     # env = minigrid.wrappers.ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     # env = ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     env = OneHotWrapper(env, | 
			
		
	
	
		
			
				
					|  |  | @ -163,10 +164,21 @@ def create_shield(grid_file, prism_path): | 
			
		
	
		
			
				
					|  |  |  |     assert result.has_scheduler | 
			
		
	
		
			
				
					|  |  |  |     assert result.has_shield | 
			
		
	
		
			
				
					|  |  |  |     shield = result.shield | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     action_dictionary = {} | 
			
		
	
		
			
				
					|  |  |  |     shield_scheduler = shield.construct() | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     for stateID in model.states: | 
			
		
	
		
			
				
					|  |  |  |         choice = shield_scheduler.get_choice(stateID) | 
			
		
	
		
			
				
					|  |  |  |         choices = choice.choice_map | 
			
		
	
		
			
				
					|  |  |  |         state_valuation = model.state_valuations.get_string(stateID) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         action_dictionary[state_valuation] = actions_to_be_executed | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     stormpy.shields.export_shield(model, shield, "Grid.shield") | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     return shield.construct(), model | 
			
		
	
		
			
				
					|  |  |  |     return action_dictionary | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | def export_grid_to_text(env, grid_file): | 
			
		
	
		
			
				
					|  |  |  |     f = open(grid_file, "w") | 
			
		
	
	
		
			
				
					|  |  | @ -195,13 +207,13 @@ def main(): | 
			
		
	
		
			
				
					|  |  |  |     export_grid_to_text(env, grid_file) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     prism_path = args.prism_path | 
			
		
	
		
			
				
					|  |  |  |     shield, model = create_shield(grid_file, prism_path) | 
			
		
	
		
			
				
					|  |  |  |     shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} | 
			
		
	
		
			
				
					|  |  |  |     shield_dict = create_shield(grid_file, prism_path) | 
			
		
	
		
			
				
					|  |  |  |     #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} | 
			
		
	
		
			
				
					|  |  |  |     | 
			
		
	
		
			
				
					|  |  |  |     print(shield_dict) | 
			
		
	
		
			
				
					|  |  |  |     for state_id in model.states: | 
			
		
	
		
			
				
					|  |  |  |         choices = shield.get_choice(state_id) | 
			
		
	
		
			
				
					|  |  |  |         print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") | 
			
		
	
		
			
				
					|  |  |  |     print(F"Shield dictionary {shield_dict}") | 
			
		
	
		
			
				
					|  |  |  |     # for state_id in model.states: | 
			
		
	
		
			
				
					|  |  |  |     #     choices = shield.get_choice(state_id) | 
			
		
	
		
			
				
					|  |  |  |     #     print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |     env_name = "mini-grid" | 
			
		
	
		
			
				
					|  |  |  |     register_env(env_name, env_creater_custom) | 
			
		
	
	
		
			
				
					|  |  | @ -213,14 +225,14 @@ def main(): | 
			
		
	
		
			
				
					|  |  |  |     config = (PPOConfig() | 
			
		
	
		
			
				
					|  |  |  |         .rollouts(num_rollout_workers=1) | 
			
		
	
		
			
				
					|  |  |  |         .resources(num_gpus=0) | 
			
		
	
		
			
				
					|  |  |  |         .environment(env="mini-grid") | 
			
		
	
		
			
				
					|  |  |  |         .environment(env="mini-grid", env_config={"shield": shield_dict }) | 
			
		
	
		
			
				
					|  |  |  |         .framework("torch")        | 
			
		
	
		
			
				
					|  |  |  |         .experimental(_disable_preprocessor_api=False) | 
			
		
	
		
			
				
					|  |  |  |         .callbacks(MyCallbacks) | 
			
		
	
		
			
				
					|  |  |  |         .rl_module(_enable_rl_module_api = False) | 
			
		
	
		
			
				
					|  |  |  |         .training(_enable_learner_api=False ,model={ | 
			
		
	
		
			
				
					|  |  |  |             "custom_model": "pa_model", | 
			
		
	
		
			
				
					|  |  |  |             "custom_model_config" : {"shield": shield_dict, "no_masking": True} | 
			
		
	
		
			
				
					|  |  |  |             "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} | 
			
		
	
		
			
				
					|  |  |  |             # "fcnet_hiddens": [256,256], | 
			
		
	
		
			
				
					|  |  |  |             # "fcnet_activation": "relu", | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
	
		
			
				
					|  |  | @ -231,9 +243,6 @@ def main(): | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         config.build() | 
			
		
	
		
			
				
					|  |  |  |     ) | 
			
		
	
		
			
				
					|  |  |  |     episode_reward = 0 | 
			
		
	
		
			
				
					|  |  |  |     terminated = truncated = False | 
			
		
	
		
			
				
					|  |  |  |     obs, info = env.reset() | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     # while not terminated and not truncated: | 
			
		
	
		
			
				
					|  |  |  |     #     action = algo.compute_single_action(obs) | 
			
		
	
	
		
			
				
					|  |  | 
 |