| 
					
					
					
				 | 
				@ -1,19 +1,20 @@ | 
			
		
		
	
		
			
				 | 
				 | 
				from sb3_contrib import MaskablePPO | 
				 | 
				 | 
				from sb3_contrib import MaskablePPO | 
			
		
		
	
		
			
				 | 
				 | 
				from sb3_contrib.common.maskable.evaluation import evaluate_policy | 
				 | 
				 | 
				from sb3_contrib.common.maskable.evaluation import evaluate_policy | 
			
		
		
	
		
			
				 | 
				 | 
				from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				from sb3_contrib.common.wrappers import ActionMasker | 
				 | 
				 | 
				from sb3_contrib.common.wrappers import ActionMasker | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				import gymnasium as gym | 
				 | 
				 | 
				import gymnasium as gym | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				from minigrid.core.actions import Actions | 
				 | 
				 | 
				from minigrid.core.actions import Actions | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				import time | 
				 | 
				 | 
				import time | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				import os | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") | 
			
		
		
	
		
			
				 | 
				 | 
				def mask_fn(env: gym.Env): | 
				 | 
				 | 
				def mask_fn(env: gym.Env): | 
			
		
		
	
		
			
				 | 
				 | 
				    return env.create_action_mask() | 
				 | 
				 | 
				    return env.create_action_mask() | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -27,14 +28,16 @@ def main(): | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) | 
				 | 
				 | 
				    shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) | 
			
		
		
	
		
			
				 | 
				 | 
				    env = gym.make(args.env, render_mode="rgb_array") | 
				 | 
				 | 
				    env = gym.make(args.env, render_mode="rgb_array") | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    env = RGBImgObsWrapper(env) # Get pixel observations | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    env = ImgObsWrapper(env) # Get rid of the 'mission' field | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    env = MiniWrapper(env) | 
			
		
		
	
		
			
				 | 
				 | 
				    env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) | 
				 | 
				 | 
				    env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) | 
			
		
		
	
		
			
				 | 
				 | 
				    env = ActionMasker(env, mask_fn) | 
				 | 
				 | 
				    env = ActionMasker(env, mask_fn) | 
			
		
		
	
		
			
				 | 
				 | 
				    model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args)) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    steps = args.steps | 
				 | 
				 | 
				    steps = args.steps | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				    model.learn(steps) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    print("Learning done, hit enter") | 
				 | 
				 | 
				    print("Learning done, hit enter") | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |