diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index c70e9b2..448f8d1 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -135,6 +135,7 @@ class MiniGridShieldHandler(ShieldHandler): action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]) toc() + print(f"{len(action_dictionary)} states in the shield") return action_dictionary