diff --git a/examples/shields/rl/rllibutils.py b/examples/shields/rl/rllibutils.py index 4a9d477..ea48af3 100644 --- a/examples/shields/rl/rllibutils.py +++ b/examples/shields/rl/rllibutils.py @@ -104,7 +104,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): print(F"Shielding is {self.mask_actions}") def create_action_mask(self): - print(f'shielding is {self.mask_actions}') if not self.mask_actions: ret = np.array([1.0] * self.max_available_actions, dtype=np.int8) return ret diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index c0eeb70..0816a54 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -313,6 +313,8 @@ def parse_arguments(argparse): parser.add_argument("--shield_value", default=0.9, type=float) parser.add_argument("--probability_displacement", default=1/4, type=float) parser.add_argument("--probability_intended", default=3/4, type=float) + parser.add_argument("--probability_turn_displacement", default=0/4, type=float) + parser.add_argument("--probability_turn_intended", default=4/4, type=float) parser.add_argument("--shield_comparision", default='relative', choices=['relative', 'absolute']) # parser.add_argument("--random_starts", default=1, type=int) args = parser.parse_args()