diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index f3a17db..ae86c40 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -18,7 +18,7 @@ from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name from shieldhandlers import MiniGridShieldHandler, create_shield_query from torch.utils.tensorboard import SummaryWriter -from callbacks import MyCallbacks +from callbacks import MyCallbacks, ShieldInfoCallback def shielding_env_creater(config):