Browse Source

removed observation changes from shielding wrapper

refactoring
sp 10 months ago
parent
commit
62c1198f25
  1. 7
      examples/shields/rl/sb3utils.py

7
examples/shields/rl/sb3utils.py

@ -12,8 +12,6 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
mask_actions=True, mask_actions=True,
): ):
super().__init__(env) super().__init__(env)
self.observation_space = env.observation_space.spaces["image"]
self.shield_handler = shield_handler self.shield_handler = shield_handler
self.mask_actions = mask_actions self.mask_actions = mask_actions
self.create_shield_at_reset = create_shield_at_reset self.create_shield_at_reset = create_shield_at_reset
@ -33,11 +31,10 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
if self.create_shield_at_reset and self.mask_actions: if self.create_shield_at_reset and self.mask_actions:
shield = self.shield_handler.create_shield(env=self.env) shield = self.shield_handler.create_shield(env=self.env)
self.shield = shield self.shield = shield
return obs["image"], infos
return obs, infos
def step(self, action): def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
obs, rew, done, truncated, info = self.env.step(action)
return obs, rew, done, truncated, info return obs, rew, done, truncated, info

Loading…
Cancel
Save