diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py index c5c431d..4a9d477 100644 --- a/examples/shields/rl/rllibutils.py +++ b/examples/shields/rl/rllibutils.py @@ -104,6 +104,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): print(F"Shielding is {self.mask_actions}") def create_action_mask(self): + print(f'shielding is {self.mask_actions}') if not self.mask_actions: ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) return ret @@ -185,9 +186,18 @@ def shielding_env_creater(config): probability_intended = args.probability_intended probability_displacement = args.probability_displacement - - env = gym.make(name, randomize_start=True,probability_intended=probability_intended, probability_displacement=probability_displacement) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding != ShieldingConfig.Disabled) + probability_turn_intended = args.probability_turn_intended + probability_turn_displacement = args.probability_turn_displacement + + + env = gym.make(name, + randomize_start=True, + probability_intended=probability_intended, + probability_displacement=probability_displacement, + probability_turn_displacement=probability_turn_displacement, + probability_turn_intended=probability_turn_intended) + + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, "vector_index") else 0,