From 5e824757d7b4002692d4d74b40d9ec54223228cd Mon Sep 17 00:00:00 2001 From: sp Date: Thu, 14 Mar 2024 22:32:43 +0100 Subject: [PATCH] disabled debug reenabled training --- examples/shields/rl/13_minigridsb.py | 1 - examples/shields/rl/utils.py | 12 ++++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 9f61046..7c36bc7 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -88,7 +88,6 @@ def main(): imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - assert(False) model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) model.save(f"{log_dir}/{expname(args)}") diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index 8123ceb..d74856b 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -125,23 +125,23 @@ class MiniGridShieldHandler(ShieldHandler): choices = choice.choice_map state_valuation = state_valuations.get_string(stateID) - print(state_valuation) - print(choices) + #print(state_valuation) + #print(choices) ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation)) booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation) booleans = {b[1]: False if b[0] == "!" else True for b in booleans} - print(ints, booleans) + #print(ints, booleans) if int(ints.get("previousActionAgent", 3)) != 3: continue if int(ints.get("clock", 0)) != 0: continue state = to_state(ints, booleans) - print(f"{state} got added with actions:") - print(get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])) + #print(f"{state} got added with actions:") + #print(get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices])) action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]) toc() - print(f"{len(action_dictionary)} states in the shield") + #print(f"{len(action_dictionary)} states in the shield") return action_dictionary