| 
					
					
						
							
						
					
					
				 | 
				@ -36,7 +36,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    def step(self, action): | 
				 | 
				 | 
				    def step(self, action): | 
			
		
		
	
		
			
				 | 
				 | 
				        obs, rew, done, truncated, info = self.env.step(action) | 
				 | 
				 | 
				        obs, rew, done, truncated, info = self.env.step(action) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        info["no_shield_action"] = not self.shield.has_key(self.env.get_symbolic_state()) | 
			
		
		
	
		
			
				 | 
				 | 
				        return obs, rew, done, truncated, info | 
				 | 
				 | 
				        return obs, rew, done, truncated, info | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				def parse_sb3_arguments(): | 
				 | 
				 | 
				def parse_sb3_arguments(): | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
						
							
						
					
					
				 | 
				@ -104,6 +104,7 @@ class InfoCallback(BaseCallback): | 
			
		
		
	
		
			
				 | 
				 | 
				        self.sum_collisions = 0 | 
				 | 
				 | 
				        self.sum_collisions = 0 | 
			
		
		
	
		
			
				 | 
				 | 
				        self.sum_opened_door = 0 | 
				 | 
				 | 
				        self.sum_opened_door = 0 | 
			
		
		
	
		
			
				 | 
				 | 
				        self.sum_picked_up = 0 | 
				 | 
				 | 
				        self.sum_picked_up = 0 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self.no_shield_action = 0 | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				    def _on_step(self) -> bool: | 
				 | 
				 | 
				    def _on_step(self) -> bool: | 
			
		
		
	
		
			
				 | 
				 | 
				        infos = self.locals["infos"][0] | 
				 | 
				 | 
				        infos = self.locals["infos"][0] | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -125,4 +126,8 @@ class InfoCallback(BaseCallback): | 
			
		
		
	
		
			
				 | 
				 | 
				            if infos["picked_up"]: | 
				 | 
				 | 
				            if infos["picked_up"]: | 
			
		
		
	
		
			
				 | 
				 | 
				                self.sum_picked_up += 1 | 
				 | 
				 | 
				                self.sum_picked_up += 1 | 
			
		
		
	
		
			
				 | 
				 | 
				            self.logger.record("info/sum_picked_up", self.sum_picked_up) | 
				 | 
				 | 
				            self.logger.record("info/sum_picked_up", self.sum_picked_up) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        if "no_shield_action" in infos: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            if infos["no_shield_action"]: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                self.no_shield_action += 1 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            self.logger.record("info/no_shield_action", self.no_shield_action) | 
			
		
		
	
		
			
				 | 
				 | 
				        return True | 
				 | 
				 | 
				        return True |