diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index 9024c4d..28936b9 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -45,6 +45,15 @@ class ShieldingConfig(Enum): def __str__(self) -> str: return self.value +def shield_needed(shielding): + return shielding in [ShieldingConfig.Training, ShieldingConfig.Evaluation, ShieldingConfig.Full] + +def shielded_evaluation(shielding): + return shielding in [ShieldingConfig.Evaluation, ShieldingConfig.Full] + +def shielded_training(shielding): + return shielding in [ShieldingConfig.Training, ShieldingConfig.Full] + class ShieldHandler(ABC): def __init__(self) -> None: pass