Browse Source

fixed shielding

refactoring
Thomas Knoll 11 months ago
parent
commit
f8a3c52b9c
  1. 16
      examples/shields/rl/rllibutils.py

16
examples/shields/rl/rllibutils.py

@@ -104,6 +104,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
print(F"Shielding is {self.mask_actions}") print(F"Shielding is {self.mask_actions}")
def create_action_mask(self): def create_action_mask(self):
print(f'shielding is {self.mask_actions}')
if not self.mask_actions: if not self.mask_actions:
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
return ret return ret
@@ -185,9 +186,18 @@ def shielding_env_creater(config):
probability_intended = args.probability_intended probability_intended = args.probability_intended
probability_displacement = args.probability_displacement probability_displacement = args.probability_displacement
env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled)
probability_turn_intended = args.probability_turn_intended
probability_turn_displacement = args.probability_turn_displacement
env = gym.make(name,
randomize_start=True,
probability_intended=probability_intended,
probability_displacement=probability_displacement,
probability_turn_displacement=probability_turn_displacement,
probability_turn_intended=probability_turn_intended)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
env = OneHotShieldingWrapper(env, env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0, config.vector_index if hasattr(config, "vector_index") else 0,

Loading…
Cancel
Save