|
@ -82,11 +82,11 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
def __init__(self, |
|
|
|
|
|
env, |
|
|
|
|
|
|
|
|
def __init__(self, |
|
|
|
|
|
env, |
|
|
shield_creator : MiniGridShieldHandler, |
|
|
shield_creator : MiniGridShieldHandler, |
|
|
shield_query_creator, |
|
|
shield_query_creator, |
|
|
create_shield_at_reset=True, |
|
|
|
|
|
|
|
|
create_shield_at_reset=False, |
|
|
mask_actions=True): |
|
|
mask_actions=True): |
|
|
super(MiniGridShieldingWrapper, self).__init__(env) |
|
|
super(MiniGridShieldingWrapper, self).__init__(env) |
|
|
self.max_available_actions = env.action_space.n |
|
|
self.max_available_actions = env.action_space.n |
|
|