diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py index deaf959..a22dda1 100644 --- a/examples/shields/rl/rllibutils.py +++ b/examples/shields/rl/rllibutils.py @@ -97,7 +97,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): } ) self.shield_creator = shield_creator - self.create_shield_at_reset = create_shield_at_reset + self.create_shield_at_reset = False # TODO self.shield = shield_creator.create_shield(env=self.env) self.mask_actions = mask_actions self.shield_query_creator = shield_query_creator