From eec87050eb8e338de91bcd66fe8310ee81e9cad2 Mon Sep 17 00:00:00 2001
From: sp
Date: Thu, 28 Mar 2024 14:29:21 +0100
Subject: [PATCH] log if shield cannot provide mask

---
Notes:
    MiniGridSbShieldingWrapper.step now sets info["no_shield_action"] to
    True when the environment's current symbolic state is not contained
    in the shield. InfoCallback counts those steps and records the
    running total under "info/no_shield_action" whenever the key is
    present in the step's info dict.

    The membership check uses the `in` operator rather than `has_key`,
    which Python 3 mappings do not provide (assumes the shield is a
    dict-like container -- confirm against the shield handler).

 examples/shields/rl/sb3utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py
index c2aa564..b177763 100644
--- a/examples/shields/rl/sb3utils.py
+++ b/examples/shields/rl/sb3utils.py
@@ -36,7 +36,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
 
     def step(self, action):
         obs, rew, done, truncated, info = self.env.step(action)
-
+        info["no_shield_action"] = self.env.get_symbolic_state() not in self.shield
         return obs, rew, done, truncated, info
 
 def parse_sb3_arguments():
@@ -104,6 +104,7 @@ class InfoCallback(BaseCallback):
         self.sum_collisions = 0
         self.sum_opened_door = 0
         self.sum_picked_up = 0
+        self.no_shield_action = 0
 
     def _on_step(self) -> bool:
         infos = self.locals["infos"][0]
@@ -125,4 +126,8 @@ class InfoCallback(BaseCallback):
             if infos["picked_up"]:
                 self.sum_picked_up += 1
             self.logger.record("info/sum_picked_up", self.sum_picked_up)
+        if "no_shield_action" in infos:
+            if infos["no_shield_action"]:
+                self.no_shield_action += 1
+            self.logger.record("info/no_shield_action", self.no_shield_action)
         return True