|
|
@ -36,7 +36,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
|
|
|
|
|
def step(self, action):
    """Step the wrapped environment and flag states the shield does not cover.

    Forwards ``action`` to ``self.env.step`` and then stores a boolean under
    ``info["no_shield_action"]``: True when the shield has no entry for the
    symbolic state the environment reports after the step.

    Args:
        action: Action passed unchanged to the wrapped environment.

    Returns:
        The ``(obs, rew, done, truncated, info)`` tuple produced by the
        wrapped environment, with ``info`` extended in place.
    """
    obs, rew, done, truncated, info = self.env.step(action)

    # Keep using ``has_key`` when the shield object provides it. Plain
    # Python 3 mappings do not (``dict.has_key`` was removed in 3.0), so
    # fall back to the ``in`` operator instead of raising AttributeError.
    has_key = getattr(self.shield, "has_key", None)
    state = self.env.get_symbolic_state()
    covered = has_key(state) if has_key is not None else state in self.shield
    info["no_shield_action"] = not covered

    return obs, rew, done, truncated, info
|
|
|
|
|
|
|
def parse_sb3_arguments(): |
|
|
@ -104,6 +104,7 @@ class InfoCallback(BaseCallback): |
|
|
|
self.sum_collisions = 0 |
|
|
|
self.sum_opened_door = 0 |
|
|
|
self.sum_picked_up = 0 |
|
|
|
self.no_shield_action = 0 |
|
|
|
|
|
|
|
def _on_step(self) -> bool: |
|
|
|
infos = self.locals["infos"][0] |
|
|
@ -125,4 +126,8 @@ class InfoCallback(BaseCallback): |
|
|
|
if infos["picked_up"]: |
|
|
|
self.sum_picked_up += 1 |
|
|
|
self.logger.record("info/sum_picked_up", self.sum_picked_up) |
|
|
|
if "no_shield_action" in infos: |
|
|
|
if infos["no_shield_action"]: |
|
|
|
self.no_shield_action += 1 |
|
|
|
self.logger.record("info/no_shield_action", self.no_shield_action) |
|
|
|
return True |