diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index 28936b9..78d9ee1 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -116,15 +116,17 @@ class MiniGridShieldHandler(ShieldHandler): state_valuations = model.state_valuations choice_labeling = model.choice_labeling - #stormpy.shields.export_shield(model, shield, "current.shield") + #stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/current.shield") + print(f"LOG: Starting to translate shield...") + tic() for stateID in model.states: choice = shield_scheduler.get_choice(stateID) choices = choice.choice_map state_valuation = state_valuations.get_string(stateID) - ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=([a-zA-Z0-9]+)', state_valuation)) - booleans = dict(re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]', state_valuation)) #TODO does not parse everything correctly? - + ints = dict(re.findall(r'([a-zA-Z][_a-zA-Z0-9]+)=(-?[a-zA-Z0-9]+)', state_valuation)) + booleans = re.findall(r'(\!?)([a-zA-Z][_a-zA-Z0-9]+)[\s\t]+', state_valuation) + booleans = {b[1]: False if b[0] == "!" else True for b in booleans} if int(ints.get("previousActionAgent", 3)) != 3: continue if int(ints.get("clock", 0)) != 0: @@ -132,6 +134,7 @@ class MiniGridShieldHandler(ShieldHandler): state = to_state(ints, booleans) action_dictionary[state] = get_allowed_actions_mask([choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1])) for choice in choices]) + toc() return action_dictionary