Browse Source

log if shield cannot provide mask

refactoring
sp 7 months ago
parent
commit
eec87050eb
  1. 7
      examples/shields/rl/sb3utils.py

7
examples/shields/rl/sb3utils.py

@ -36,7 +36,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def step(self, action):
obs, rew, done, truncated, info = self.env.step(action)
info["no_shield_action"] = not self.shield.has_key(self.env.get_symbolic_state())
return obs, rew, done, truncated, info
def parse_sb3_arguments():
@ -104,6 +104,7 @@ class InfoCallback(BaseCallback):
self.sum_collisions = 0
self.sum_opened_door = 0
self.sum_picked_up = 0
self.no_shield_action = 0
def _on_step(self) -> bool:
infos = self.locals["infos"][0]
@ -125,4 +126,8 @@ class InfoCallback(BaseCallback):
if infos["picked_up"]:
self.sum_picked_up += 1
self.logger.record("info/sum_picked_up", self.sum_picked_up)
if "no_shield_action" in infos:
if infos["no_shield_action"]:
self.no_shield_action += 1
self.logger.record("info/no_shield_action", self.no_shield_action)
return True
Loading…
Cancel
Save