From 27004e091603bcf363251a8c2ce5b622e4c4a76a Mon Sep 17 00:00:00 2001 From: sp Date: Fri, 19 Jan 2024 10:39:49 +0100 Subject: [PATCH] check all actions for mask --- examples/shields/rl/utils.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index 78d9ee1..42560e9 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -154,7 +154,7 @@ def create_log_dir(args): return log_dir def get_allowed_actions_mask(actions): - action_mask = [0.0] * 3 + [1.0] * 4 + action_mask = [0.0] * 7 actions_labels = [label for labels in actions for label in list(labels)] for action_label in actions_labels: if "move" in action_label: @@ -163,12 +163,21 @@ def get_allowed_actions_mask(actions): action_mask[0] = 1.0 elif "right" in action_label: action_mask[1] = 1.0 + elif "pickup" in action_label: + action_mask[3] = 1.0 + elif "drop" in action_label: + action_mask[4] = 1.0 + elif "toggle" in action_label: + action_mask[5] = 1.0 + elif "done" in action_label: + action_mask[6] = 1.0 return action_mask def common_parser(): parser = argparse.ArgumentParser() parser.add_argument("--env", help="gym environment to load", + choices=gym.envs.registry.keys(), default="MiniGrid-LavaSlipperyCliff-16x13-v0") parser.add_argument("--grid_file", default="grid.txt")