|
@ -12,8 +12,6 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
mask_actions=True, |
|
|
mask_actions=True, |
|
|
): |
|
|
): |
|
|
super().__init__(env) |
|
|
super().__init__(env) |
|
|
self.observation_space = env.observation_space.spaces["image"] |
|
|
|
|
|
|
|
|
|
|
|
self.shield_handler = shield_handler |
|
|
self.shield_handler = shield_handler |
|
|
self.mask_actions = mask_actions |
|
|
self.mask_actions = mask_actions |
|
|
self.create_shield_at_reset = create_shield_at_reset |
|
|
self.create_shield_at_reset = create_shield_at_reset |
|
@ -33,11 +31,10 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): |
|
|
if self.create_shield_at_reset and self.mask_actions: |
|
|
if self.create_shield_at_reset and self.mask_actions: |
|
|
shield = self.shield_handler.create_shield(env=self.env) |
|
|
shield = self.shield_handler.create_shield(env=self.env) |
|
|
self.shield = shield |
|
|
self.shield = shield |
|
|
return obs["image"], infos |
|
|
|
|
|
|
|
|
return obs, infos |
|
|
|
|
|
|
|
|
def step(self, action): |
|
|
def step(self, action): |
|
|
orig_obs, rew, done, truncated, info = self.env.step(action) |
|
|
|
|
|
obs = orig_obs["image"] |
|
|
|
|
|
|
|
|
obs, rew, done, truncated, info = self.env.step(action) |
|
|
|
|
|
|
|
|
return obs, rew, done, truncated, info |
|
|
return obs, rew, done, truncated, info |
|
|
|
|
|
|
|
|