Browse Source

randomize start

refactoring
Thomas Knoll 11 months ago
parent
commit
5c9064ecbe
  1. 2
      examples/shields/rl/11_minigridrl.py
  2. 2
      examples/shields/rl/wrappers.py

2
examples/shields/rl/11_minigridrl.py

@@ -25,7 +25,7 @@ def shielding_env_creater(config):
 shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula, args.shield_value, args.prism_config)
-env = gym.make(name)
+env = gym.make(name, randomize_start=True)
 env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
 shield_query_creator=create_shield_query,
 mask_actions=args.shielding != ShieldingConfig.Disabled,

2
examples/shields/rl/wrappers.py

@@ -134,7 +134,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
 print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}")
 assert(False)
-allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
+allowed = 1.0 # random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
 if allowed_action.prob == 0 and allowed:
 assert False
 if allowed:

Loading…
Cancel
Save