|
|
@ -154,7 +154,7 @@ def create_log_dir(args): |
|
|
|
return log_dir |
|
|
|
|
|
|
|
def get_allowed_actions_mask(actions): |
|
|
|
action_mask = [0.0] * 3 + [1.0] * 4 |
|
|
|
action_mask = [0.0] * 7 |
|
|
|
actions_labels = [label for labels in actions for label in list(labels)] |
|
|
|
for action_label in actions_labels: |
|
|
|
if "move" in action_label: |
|
|
@ -163,12 +163,21 @@ def get_allowed_actions_mask(actions): |
|
|
|
action_mask[0] = 1.0 |
|
|
|
elif "right" in action_label: |
|
|
|
action_mask[1] = 1.0 |
|
|
|
elif "pickup" in action_label: |
|
|
|
action_mask[3] = 1.0 |
|
|
|
elif "drop" in action_label: |
|
|
|
action_mask[4] = 1.0 |
|
|
|
elif "toggle" in action_label: |
|
|
|
action_mask[5] = 1.0 |
|
|
|
elif "done" in action_label: |
|
|
|
action_mask[6] = 1.0 |
|
|
|
return action_mask |
|
|
|
|
|
|
|
def common_parser(): |
|
|
|
parser = argparse.ArgumentParser() |
|
|
|
parser.add_argument("--env", |
|
|
|
help="gym environment to load", |
|
|
|
choices=gym.envs.registry.keys(), |
|
|
|
default="MiniGrid-LavaSlipperyCliff-16x13-v0") |
|
|
|
|
|
|
|
parser.add_argument("--grid_file", default="grid.txt") |
|
|
|