diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py index ea48af3..deaf959 100644 --- a/examples/shields/rl/rllibutils.py +++ b/examples/shields/rl/rllibutils.py @@ -82,11 +82,11 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): class MiniGridShieldingWrapper(gym.core.Wrapper): - def __init__(self, - env, + def __init__(self, + env, shield_creator : MiniGridShieldHandler, shield_query_creator, - create_shield_at_reset=True, + create_shield_at_reset=False, mask_actions=True): super(MiniGridShieldingWrapper, self).__init__(env) self.max_available_actions = env.action_space.n