diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 1144d36..611ac27 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -42,7 +42,7 @@ def create_log_dir(args): return F"{args.log_dir}sh:{args.shielding}-env:{args.env}" def test_name(args): - return F"sh:{args.shielding}-env:{args.env}" + return F"{args.expname}/sh:{args.shielding}-env:{args.env}" def get_action_index_mapping(actions): for action_str in actions: @@ -94,7 +94,7 @@ def parse_arguments(argparse): parser.add_argument("--workers", type=int, default=1) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) parser.add_argument("--steps", default=20_000, type=int) - + parser.add_argument("--expname", default="exp") args = parser.parse_args() return args