diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index 454552b..bcae0a3 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -10,7 +10,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): def __init__(self, env, shield_handler : MiniGridShieldHandler, - create_shield_at_reset = True, + create_shield_at_reset = False, ): super().__init__(env) self.shield_handler = shield_handler