diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 569d0dc..150d7ea 100755 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -25,7 +25,7 @@ def shielding_env_creater(config): shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config) - env = gym.make(name) + env = gym.make(name, randomize_start=True) env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding != ShieldingConfig.Disabled, diff --git a/examples/shields/rl/wrappers.py b/examples/shields/rl/wrappers.py index e706ff0..bbd9bac 100644 --- a/examples/shields/rl/wrappers.py +++ b/examples/shields/rl/wrappers.py @@ -134,7 +134,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}") assert(False) - allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] + allowed = 1.0 # random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] if allowed_action.prob == 0 and allowed: assert False if allowed: