diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index dcf20d2..29571c8 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -102,6 +102,7 @@ class InfoCallback(BaseCallback): self.sum_goal = 0 self.sum_lava = 0 self.sum_collisions = 0 + self.sum_opened_door = 0 def _on_step(self) -> bool: infos = self.locals["infos"][0] @@ -115,4 +116,8 @@ class InfoCallback(BaseCallback): if infos["collision"]: self.sum_collisions += 1 self.logger.record("info/sum_collision", self.sum_collisions) + if "opened_door" in infos: + if infos["opened_door"]: + self.sum_opened_door += 1 + self.logger.record("info/sum_opened_door", self.sum_opened_door) return True