diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index b49c38b..f18f2af 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -12,8 +12,6 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): mask_actions=True, ): super().__init__(env) - self.observation_space = env.observation_space.spaces["image"] - self.shield_handler = shield_handler self.mask_actions = mask_actions self.create_shield_at_reset = create_shield_at_reset @@ -33,11 +31,10 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): if self.create_shield_at_reset and self.mask_actions: shield = self.shield_handler.create_shield(env=self.env) self.shield = shield - return obs["image"], infos + return obs, infos def step(self, action): - orig_obs, rew, done, truncated, info = self.env.step(action) - obs = orig_obs["image"] + obs, rew, done, truncated, info = self.env.step(action) return obs, rew, done, truncated, info