|
@ -104,6 +104,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): |
|
|
print(F"Shielding is {self.mask_actions}") |
|
|
print(F"Shielding is {self.mask_actions}") |
|
|
|
|
|
|
|
|
def create_action_mask(self): |
|
|
def create_action_mask(self): |
|
|
|
|
|
print(f'shielding is {self.mask_actions}') |
|
|
if not self.mask_actions: |
|
|
if not self.mask_actions: |
|
|
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|
|
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) |
|
|
return ret |
|
|
return ret |
|
@ -185,9 +186,18 @@ def shielding_env_creater(config): |
|
|
|
|
|
|
|
|
probability_intended = args.probability_intended |
|
|
probability_intended = args.probability_intended |
|
|
probability_displacement = args.probability_displacement |
|
|
probability_displacement = args.probability_displacement |
|
|
|
|
|
probability_turn_intended = args.probability_turn_intended |
|
|
|
|
|
probability_turn_displacement = args.probability_turn_displacement |
|
|
|
|
|
|
|
|
env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement) |
|
|
|
|
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
env = gym.make(name, |
|
|
|
|
|
randomize_start=True, |
|
|
|
|
|
probability_intended=probability_intended, |
|
|
|
|
|
probability_displacement=probability_displacement, |
|
|
|
|
|
probability_turn_displacement=probability_turn_displacement, |
|
|
|
|
|
probability_turn_intended=probability_turn_intended) |
|
|
|
|
|
|
|
|
|
|
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) |
|
|
|
|
|
|
|
|
env = OneHotShieldingWrapper(env, |
|
|
env = OneHotShieldingWrapper(env, |
|
|
config.vector_index if hasattr(config, "vector_index") else 0, |
|
|
config.vector_index if hasattr(config, "vector_index") else 0, |
|
|