Browse Source

args for turn prob

refactoring
Thomas Knoll 11 months ago
parent
commit
555511bd34
  1. 1
      examples/shields/rl/rllibutils.py
  2. 2
      examples/shields/rl/utils.py

1
examples/shields/rl/rllibutils.py

@ -104,7 +104,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
print(F"Shielding is {self.mask_actions}") print(F"Shielding is {self.mask_actions}")
def create_action_mask(self): def create_action_mask(self):
print(f'shielding is {self.mask_actions}')
if not self.mask_actions: if not self.mask_actions:
ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
return ret return ret

2
examples/shields/rl/utils.py

@ -313,6 +313,8 @@ def parse_arguments(argparse):
parser.add_argument("--shield_value", default=0.9, type=float) parser.add_argument("--shield_value", default=0.9, type=float)
parser.add_argument("--probability_displacement", default=1/4, type=float) parser.add_argument("--probability_displacement", default=1/4, type=float)
parser.add_argument("--probability_intended", default=3/4, type=float) parser.add_argument("--probability_intended", default=3/4, type=float)
parser.add_argument("--probability_turn_displacement", default=0/4, type=float)
parser.add_argument("--probability_turn_intended", default=4/4, type=float)
parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute'])
# parser.add_argument("--random_starts", default=1, type=int) # parser.add_argument("--random_starts", default=1, type=int)
args = parser.parse_args() args = parser.parse_args()

Loading…
Cancel
Save