| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -3,6 +3,8 @@ import numpy as np | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import random | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from utils import MiniGridShieldHandler, common_parser | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from stable_baselines3.common.logger import Image | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class MiniGridSbShieldingWrapper(gym.core.Wrapper): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -43,3 +45,40 @@ def parse_sb3_arguments(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args = parser.parse_args() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return args | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class ImageRecorderCallback(BaseCallback): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, verbose=0): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__(verbose) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _on_training_start(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        image = self.training_env.render(mode="rgb_array") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _on_step(self): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					class InfoCallback(BaseCallback): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    Custom callback for plotting additional values in tensorboard. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def __init__(self, verbose=0): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        super().__init__(verbose) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.sum_goal = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.sum_lava = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.sum_collisions = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def _on_step(self) -> bool: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        infos = self.locals["infos"][0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if infos["reached_goal"]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.sum_goal += 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if infos["ran_into_lava"]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.sum_lava += 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.logger.record("info/sum_reached_goal", self.sum_goal) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        self.logger.record("info/sum_ran_into_lava", self.sum_lava) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if "collision" in infos: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if infos["collision"]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                self.sum_collision += 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            self.logger.record("info/sum_collision", sum_collisions) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return True |