| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -9,24 +9,25 @@ from minigrid.core.actions import Actions | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import time | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from sb3utils import MiniGridSbShieldingWrapper | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def mask_fn(env: gym.Env): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return env.create_action_mask() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    import argparse | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args = parse_arguments(argparse) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args.grid_path = F"{args.grid_path}.txt" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args.prism_path = F"{args.prism_path}.prism" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args = parse_sb3_arguments() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    formula = args.formula | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    shield_value = args.shield_value | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    shield_comparison = args.shield_comparison | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env = gym.make(args.env, render_mode="rgb_array") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env = ActionMasker(env, mask_fn) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -35,14 +36,16 @@ def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    model.learn(steps) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					  #W  mean_reward, std_reward = evaluate_policy(model, model.get_env()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    print("Learning done, hit enter") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    input("") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    vec_env = model.get_env() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    obs = vec_env.reset() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    terminated = truncated = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    while not terminated and not truncated: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        action_masks = None | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        action, _states = model.predict(obs, action_masks=action_masks) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        print(action) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        obs, reward, terminated, truncated, info = env.step(action) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # action, _states = model.predict(obs, deterministic=True) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # obs, rewards, dones, info = vec_env.step(action) | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				
					
  |