| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -14,6 +14,7 @@ from abc import ABC | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import re | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import sys | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import gymnasium as gym | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from minigrid.core.actions import Actions | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from minigrid.core.state import to_state | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -128,10 +129,10 @@ class MiniGridShieldHandler(ShieldHandler): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def create_log_dir(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return F"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return f"{args.log_dir}sh:{args.shielding}-value:{args.shield_value}-comp:{args.shield_comparison}-env:{args.env}-conf:{args.prism_config}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def test_name(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return F"{args.expname}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return f"{args.expname}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def get_allowed_actions_mask(actions): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    action_mask = [0.0] * 3 + [1.0] * 4 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -162,3 +163,19 @@ def common_parser(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    parser.add_argument("--shield_value", default=0.9, type=float) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    parser.add_argument("--shield_comparison", default='absolute', choices=['relative', 'absolute']) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return parser | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class MiniWrapper(gym.Wrapper): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, env): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__(env) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.env = env | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def reset(self, *, seed=None, options=None): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        obs, info = self.env.reset(seed=seed, options=options) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return obs.transpose(1,0,2), info | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def observations(self, obs): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return obs | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def step(self, action): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        obs, reward, terminated, truncated, info = self.env.step(action) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return obs.transpose(1,0,2), reward, terminated, truncated, info |