Browse Source

disabled debug reenabled training

refactoring
sp 8 months ago
parent
commit
5e824757d7
  1. 1
      examples/shields/rl/13_minigridsb.py
  2. 12
      examples/shields/rl/utils.py

1
examples/shields/rl/13_minigridsb.py

@ -88,7 +88,6 @@ def main():
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
assert(False)
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback])
model.save(f"{log_dir}/{expname(args)}") model.save(f"{log_dir}/{expname(args)}")

12
examples/shields/rl/utils.py

@ -125,23 +125,23 @@ class MiniGridShieldHandler(ShieldHandler):
choices = choice.choice_map choices = choice.choice_map
state_valuation = state_valuations.get_string(stateID) state_valuation = state_valuations.get_string(stateID)
print(state_valuation)
print(choices)
#print(state_valuation)
#print(choices)
ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation)) ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation))
booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation) booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation)
booleans = {b[1]: False if b[0] == "!" else True for b in booleans} booleans = {b[1]: False if b[0] == "!" else True for b in booleans}
print(ints, booleans)
#print(ints, booleans)
if int(ints.get("previousActionAgent", 3)) != 3: if int(ints.get("previousActionAgent", 3)) != 3:
continue continue
if int(ints.get("clock", 0)) != 0: if int(ints.get("clock", 0)) != 0:
continue continue
state = to_state(ints, booleans) state = to_state(ints, booleans)
print(f"{state} got added with actions:")
print(get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]))
#print(f"{state} got added with actions:")
#print(get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]))
action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]) action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])
toc() toc()
print(f"{len(action_dictionary)} states in the shield")
#print(f"{len(action_dictionary)} states in the shield")
return action_dictionary return action_dictionary

Loading…
Cancel
Save