diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index 29571c8..25b7c51 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -103,6 +103,7 @@ class InfoCallback(BaseCallback): self.sum_lava = 0 self.sum_collisions = 0 self.sum_opened_door = 0 + self.sum_picked_up = 0 def _on_step(self) -> bool: infos = self.locals["infos"][0] @@ -120,4 +121,8 @@ class InfoCallback(BaseCallback): if infos["opened_door"]: self.sum_opened_door += 1 self.logger.record("info/sum_opened_door", self.sum_opened_door) + if "picked_up" in infos: + if infos["picked_up"]: + self.sum_picked_up += 1 + self.logger.record("info/sum_picked_up", self.sum_picked_up) return True