from shield_handler import ShieldHandler

import tempfile, datetime, shutil
import os

import stormpy
import stormpy.core
import stormpy.simulator

from itertools import chain, combinations
import re

from minigrid.core.state import to_state, State
from utils import tic, toc, get_allowed_actions_mask_for_labels, id_to_symbolic, StateClassification
from loguru import logger


class MiniGridShieldHandler(ShieldHandler):
    """Compute a shield (state -> allowed-action mask) for a MiniGrid environment.

    The handler writes the grid to a text file, obtains a PRISM program for it
    (either by copying a given file or by running an external converter
    binary), builds and model checks the program with stormpy, and translates
    the resulting shield into a dictionary keyed by symbolic state.
    """

    ACTION_SPACE_SIZE = 4

    def __init__(
        self,
        grid_to_prism_binary,
        grid_file,
        prism_path,
        formula,
        env,
        prism_config=None,
        shield_value=0.9,
        shield_comparison='absolute',
        ignore_view: bool = False,
        nocleanup=False,
        prism_file=None,
    ) -> None:
        """Create a fresh working directory and store the shield configuration.

        Args:
            grid_to_prism_binary: Command used to convert the grid text file
                into a PRISM file.
            grid_file: File name of the grid text file, relative to the
                working directory created here.
            prism_path: File name of the PRISM file, relative to the working
                directory created here.
            formula: Indexable sequence of formula fragments; the first and
                last entries wrap the inner entries (see
                ``__create_shield_dict``).
            env: The environment; stored on the instance.
            prism_config: Optional configuration passed to the converter via
                ``-c``.
            shield_value: Threshold passed to the primary shield expression.
            shield_comparison: ``"absolute"`` selects the absolute comparison,
                any other value selects the relative one.
            ignore_view: Forwarded to ``get_allowed_actions_mask_for_labels``.
            nocleanup: Stored on the instance; not acted upon in this class.
            prism_file: Optional existing PRISM file that is copied instead of
                running the converter.
        """
        logger.success("MiniGridShieldHandler ctor called")

        os.makedirs("shielding_files", exist_ok=True)
        timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        # NOTE(review): relies on a private tempfile helper for the random suffix.
        self.tmp_dir_name = f"shielding_files/{timestamp}_{next(tempfile._get_candidate_names())}"
        os.makedirs(self.tmp_dir_name)

        self.grid_file = self.tmp_dir_name + "/" + grid_file
        self.grid_to_prism_binary = grid_to_prism_binary
        self.prism_path = self.tmp_dir_name + "/" + prism_path
        self.prism_config = prism_config
        self.prism_file = prism_file

        # Filled lazily by create_shield(); None means "not computed yet".
        self.action_dictionary = None
        self.ignore_view = ignore_view
        self.env = env

        # Key order is preserved because get_dtcontrol_csv iterates this dict.
        self.dtcontrol_csvs = {
            "base": set(),
            "base_basic_preds": set(),
            "formulas_critical": list(),
            "formulas_dangerous": list(),
            "safe_unsafe": set(),
            "critical_base": list(),
            "dangerous_base": list(),
            "critical_state_action": list(),
            "dangerous_state_action": list(),
        }

        self.formula = formula

        # The enum value is only needed to build the two shield expressions;
        # the attribute itself keeps the caller-supplied string.
        if shield_comparison == "absolute":
            comparison = stormpy.logic.ShieldComparison.ABSOLUTE
        else:
            comparison = stormpy.logic.ShieldComparison.RELATIVE
        self.shield_expression = stormpy.logic.ShieldExpression(
            stormpy.logic.ShieldingType.PRE_SAFETY, comparison, shield_value
        )
        # The fallback shield uses the same comparison with a fixed value of 1.0.
        self.fallback_shield_expression = stormpy.logic.ShieldExpression(
            stormpy.logic.ShieldingType.PRE_SAFETY, comparison, 1.0
        )
        self.shield_comparison = shield_comparison
        self.nocleanup = nocleanup

    def __del__(self):
        """Do nothing; removal of the working directory is currently disabled."""
        pass
        #if not self.nocleanup:
        #    shutil.rmtree(self.tmp_dir_name)

    def __export_grid_to_text(self, env):
        """Write ``env.printGrid(init=True)`` to ``self.grid_file``."""
        with open(self.grid_file, "w") as f:
            f.write(env.printGrid(init=True))

    def __create_prism(self):
        """Produce the PRISM file at ``self.prism_path``.

        If ``self.prism_file`` is set it is copied verbatim. Otherwise the
        converter binary is run on the grid file, with ``-c`` added when a
        configuration was given.

        Raises:
            AssertionError: If the converter exits with a non-zero status.
        """
        if self.prism_file is not None:
            logger.info(self.prism_file)
            logger.info(self.prism_path)
            shutil.copyfile(self.prism_file, self.prism_path)
            return

        if self.prism_config is None:
            result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
            logger.info(self.grid_file)
            logger.info(self.prism_path)
        else:
            result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")

        assert result == 0, "Prism file could not be generated"

    def __create_shield_dict(self):
        """Build the model, compute the shields and translate them to a dict.

        Returns:
            The dictionary mapping each symbolic state to the action mask
            returned by ``get_allowed_actions_mask_for_labels``. The same
            dictionary is stored as ``self.action_dictionary``.

        Raises:
            AssertionError: If the primary model-checking result has no shield.
        """
        self.program = stormpy.parse_prism_program(self.prism_path)

        # formula[0] and formula[-1] wrap the inner fragments: once joined
        # with " | " into a combined property, once per individual fragment.
        prefix, suffix = self.formula[0], self.formula[-1]
        inner = self.formula[1:-1]
        f_comb = prefix + " | ".join(inner) + suffix
        f_sep = [prefix + f + suffix for f in inner]
        f_sep.append(f_comb)

        formulas_comb = stormpy.parse_properties_for_prism_program(f_comb, self.program)
        formulas_sep = stormpy.parse_properties_for_prism_program("; ".join(f_sep), self.program)

        options = stormpy.BuilderOptions([p.raw_formula for p in formulas_sep])
        options.set_build_state_valuations(True)
        options.set_build_choice_labels(True)
        options.set_build_all_labels()

        print("LOG: Starting with explicit model creation...")
        tic()
        model = stormpy.build_sparse_model_with_options(self.program, options)
        self.model = model
        toc()

        print("LOG: Starting with model checking...")
        tic()
        result = stormpy.model_checking(
            model, formulas_comb[0], extract_scheduler=True, shield_expression=self.shield_expression
        )
        fallback_result = stormpy.model_checking(
            model, formulas_comb[0], extract_scheduler=True, shield_expression=self.fallback_shield_expression
        )
        toc()

        assert result.has_shield
        shield = result.shield
        fallback_shield = fallback_result.shield
        self.fallback_shield = fallback_shield.construct()
        self.shield_scheduler = shield.construct()
        self.state_valuations = model.state_valuations
        self.choice_labeling = model.choice_labeling

        stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/shield")
        stormpy.shields.export_shield(model, fallback_shield, self.tmp_dir_name + "/fallback_shield")

        action_dictionary = dict()
        self.dangerous_states = list()

        print("LOG: Starting to translate shield...")
        tic()
        for state_id in model.states:
            # A single lookup covers both the "no symbolic state" and the
            # "falsy symbolic state" cases.
            state = id_to_symbolic(state_id, self.state_valuations)
            if not state:
                continue
            if 'deadlock' in state_id.labels or self.formula[1] in state_id.labels:
                continue

            choices = self.shield_scheduler.get_choice(state_id).choice_map
            action_dictionary[state] = get_allowed_actions_mask_for_labels(
                [
                    self.choice_labeling.get_labels_of_choice(model.get_choice_index(state_id, entry[1]))
                    for entry in choices
                ],
                self.ignore_view,
            )

            if sum(action_dictionary[state]) != 0:
                continue

            # The primary shield allows nothing here: fall back to the
            # actions with the minimal value in the fallback shield.
            choices = self.fallback_shield.get_choice(state_id).choice_map
            try:
                logger.warning(f"State: [colAgent={state.colAgent}, rowAgent={state.rowAgent}] is dangerous...")
                fallback_value = min(choices, key=lambda c: c[0])[0]
                fallback_actions = [c[1] for c in choices if c[0] == fallback_value]
                choice_labels = [
                    self.choice_labeling.get_labels_of_choice(model.get_choice_index(state_id, entry[1]))
                    for entry in choices
                ]
                enabled_fallback_actions_labels = [choice_labels[i] for i in fallback_actions]
                # NOTE(review): these writes are overwritten by the assignment
                # below; they are kept because an out-of-range index here is
                # what routes the state into the IndexError handler.
                for c in fallback_actions:
                    action_dictionary[state][c] = 1.0
                action_dictionary[state] = get_allowed_actions_mask_for_labels(
                    enabled_fallback_actions_labels, self.ignore_view
                )
                logger.warning(f"... only allowing {enabled_fallback_actions_labels}")
                self.dangerous_states.append(state)
            except (ValueError, IndexError) as e:
                # ValueError: min() over an empty choice map;
                # IndexError: a fallback action index outside a list above.
                logger.error(f"{state=} is unsafe, {e=}")
                continue
        toc()

        self.action_dictionary = action_dictionary
        logger.success(f"Length of the action_dictionary of the shield:\t\t{len(action_dictionary)}")
        return action_dictionary

    def create_shield(self, **kwargs):
        """Return the shield dictionary, computing it on the first call.

        Args:
            **kwargs: Must contain ``env`` when the shield has not been
                computed yet; its grid is exported to the grid text file.

        Returns:
            The cached or newly computed action dictionary.
        """
        if self.action_dictionary is not None:
            return self.action_dictionary

        env = kwargs["env"]
        self.__export_grid_to_text(env)
        logger.info("creating prism file")
        self.__create_prism()
        print("Computing new shield")
        return self.__create_shield_dict()

    def get_model(self):
        """Return the sparse model built by the shield computation."""
        return self.model

    def get_dtcontrol_csv(self, action_comparison="absolute"):
        """Assemble the collected dtcontrol rows together with their headers.

        NOTE(review): uses ``self.translator``, which is not set anywhere in
        this class - presumably provided by the base class or a caller.

        Args:
            action_comparison: Forwarded to the translator when querying the
                column types for the state-action tables.

        Returns:
            A tuple ``(result, meta_data, meta_data_action, meta_data_basic)``
            where ``result`` maps every key of ``self.dtcontrol_csvs`` to its
            rows prefixed with the matching two-line header.
        """
        translator = self.translator
        pre = ["#PERMISSIVE", f"BEGIN {translator.number_state_variables()} 1"]
        pre_basic = ["#PERMISSIVE", f"BEGIN {translator.number_state_variables(only_basic_preds = True)} 1"]
        pre_action = ["#PERMISSIVE", f"BEGIN {translator.number_state_variables(action=True)} 1"]

        types, values = translator.get_column_types_and_categories()
        meta_data = {"x_column_types": types, "x_category_names": values}

        types, values = translator.get_column_types_and_categories(
            action_comparison=action_comparison, action=True
        )
        meta_data_action = {"x_column_types": types, "x_category_names": values}

        types, values = translator.get_column_types_and_categories(only_basic_preds=True)
        meta_data_basic = {"x_column_types": types, "x_category_names": values}

        result = dict()
        for key, rows in self.dtcontrol_csvs.items():
            if key == "formulas_dangerous" or key == "formulas_critical":
                # These entries hold one row list per formula; prefix each.
                result[key] = [pre + v for v in rows]
            elif "state_action" in key:
                result[key] = pre_action + rows
            elif "basic_preds" in key:
                result[key] = pre_basic + list(rows)
            else:
                result[key] = pre + list(rows)
        return result, meta_data, meta_data_action, meta_data_basic

    @property
    def feature_space(self, only_basic_preds=False) -> list[str]:
        """Feature names, from the translator if present, else from the model.

        ``only_basic_preds`` always takes its default because a property
        getter is invoked without extra arguments. When accessing the
        translator raises ``AttributeError``, the feature space of the first
        model state is used and cached in ``self._feature_space``.
        """
        try:
            return self.translator.get_feature_space(only_basic_preds=only_basic_preds)
        except AttributeError:
            self._feature_space = id_to_symbolic(self.model.states[0], self.state_valuations).feature_space
            #fs = []
            #adv_dist = 3
            #for i in range(0, len(self._feature_space)):
            #    fs.append(self._feature_space[i])
            #    feature = self._feature_space[i]
            #    if i >= adv_dist and feature.startswith("view"):
            #        color = feature.replace("view", '')
            #        fs.append(f"dist{color}")
            #self._feature_space = fs
            return self._feature_space

    @property
    def num_features(self, only_basic_preds=False) -> int:
        """Number of entries in ``feature_space``, cached after first access.

        ``only_basic_preds`` is kept for signature compatibility only; a
        property getter never receives it.
        """
        try:
            return self._num_features
        except AttributeError:
            # len() accepts exactly one argument; passing only_basic_preds as
            # a second one raised TypeError on every first access.
            self._num_features = len(self.feature_space)
            return self._num_features

    @property
    def length_shield(self) -> int:
        """Number of states in the computed action dictionary."""
        return len(self.action_dictionary)