Browse Source

check all actions for mask

refactoring
sp 9 months ago
parent
commit
27004e0916
  1. 11
      examples/shields/rl/utils.py

11
examples/shields/rl/utils.py

@ -154,7 +154,7 @@ def create_log_dir(args):
return log_dir
def get_allowed_actions_mask(actions):
action_mask = [0.0] * 3 + [1.0] * 4
action_mask = [0.0] * 7
actions_labels = [label for labels in actions for label in list(labels)]
for action_label in actions_labels:
if "move" in action_label:
@ -163,12 +163,21 @@ def get_allowed_actions_mask(actions):
action_mask[0] = 1.0
elif "right" in action_label:
action_mask[1] = 1.0
elif "pickup" in action_label:
action_mask[3] = 1.0
elif "drop" in action_label:
action_mask[4] = 1.0
elif "toggle" in action_label:
action_mask[5] = 1.0
elif "done" in action_label:
action_mask[6] = 1.0
return action_mask
def common_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--env",
help="gym environment to load",
choices=gym.envs.registry.keys(),
default="MiniGrid-LavaSlipperyCliff-16x13-v0")
parser.add_argument("--grid_file", default="grid.txt")

Loading…
Cancel
Save