|
@ -133,11 +133,12 @@ class MiniGridShieldHandler(ShieldHandler): |
|
|
|
|
|
|
|
|
return self.__create_shield_dict() |
|
|
return self.__create_shield_dict() |
|
|
|
|
|
|
|
|
|
|
|
def expname(args): |
|
|
|
|
|
return f"{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}" |
|
|
def create_log_dir(args): |
|
|
def create_log_dir(args): |
|
|
return f"{args.log_dir}/{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}" |
|
|
|
|
|
|
|
|
|
|
|
def test_name(args): |
|
|
|
|
|
return f"{args.expname}" |
|
|
|
|
|
|
|
|
log_dir = f"{args.log_dir}/{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{expname(args)}" |
|
|
|
|
|
os.makedirs(log_dir, exist_ok=True) |
|
|
|
|
|
return log_dir |
|
|
|
|
|
|
|
|
def get_allowed_actions_mask(actions): |
|
|
def get_allowed_actions_mask(actions): |
|
|
action_mask = [0.0] * 3 + [1.0] * 4 |
|
|
action_mask = [0.0] * 3 + [1.0] * 4 |
|
@ -155,7 +156,7 @@ def common_parser(): |
|
|
parser = argparse.ArgumentParser() |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument("--env", |
|
|
parser.add_argument("--env", |
|
|
help="gym environment to load", |
|
|
help="gym environment to load", |
|
|
default="MiniGrid-LavaSlipperyCliff-16x12-v0") |
|
|
|
|
|
|
|
|
default="MiniGrid-LavaSlipperyCliff-16x13-v0") |
|
|
|
|
|
|
|
|
parser.add_argument("--grid_file", default="grid.txt") |
|
|
parser.add_argument("--grid_file", default="grid.txt") |
|
|
parser.add_argument("--prism_output_file", default="grid.prism") |
|
|
parser.add_argument("--prism_output_file", default="grid.prism") |
|
|