Browse Source

ensure that exp log directory exists

refactoring
sp 11 months ago
parent
commit
8cbbef4006
  1. 11
      examples/shields/rl/utils.py

11
examples/shields/rl/utils.py

@ -133,11 +133,12 @@ class MiniGridShieldHandler(ShieldHandler):
return self.__create_shield_dict() return self.__create_shield_dict()
def expname(args):
return f"{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}"
def create_log_dir(args): def create_log_dir(args):
return f"{args.log_dir}/{args.env}_{args.shielding}_{args.shield_comparison}_{args.shield_value}"
def test_name(args):
return f"{args.expname}"
log_dir = f"{args.log_dir}/{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{expname(args)}"
os.makedirs(log_dir, exist_ok=True)
return log_dir
def get_allowed_actions_mask(actions): def get_allowed_actions_mask(actions):
action_mask = [0.0] * 3 + [1.0] * 4 action_mask = [0.0] * 3 + [1.0] * 4
@ -155,7 +156,7 @@ def common_parser():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--env", parser.add_argument("--env",
help="gym environment to load", help="gym environment to load",
default="MiniGrid-LavaSlipperyCliff-16x12-v0")
default="MiniGrid-LavaSlipperyCliff-16x13-v0")
parser.add_argument("--grid_file", default="grid.txt") parser.add_argument("--grid_file", default="grid.txt")
parser.add_argument("--prism_output_file", default="grid.prism") parser.add_argument("--prism_output_file", default="grid.prism")

Loading…
Cancel
Save