You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

232 lines
11 KiB

from shield_handler import ShieldHandler
import tempfile, datetime, shutil
import os
import stormpy
import stormpy.core
import stormpy.simulator
from itertools import chain, combinations
import re
from minigrid.core.state import to_state, State
from utils import tic, toc, get_allowed_actions_mask_for_labels, id_to_symbolic, StateClassification
from loguru import logger
class MiniGridShieldHandler(ShieldHandler):
    """Shield handler computing a pre-safety shield for a MiniGrid environment.

    Workflow (driven by :meth:`create_shield`):

    1. export the environment grid to a text file,
    2. turn it into a PRISM program, either by running ``grid_to_prism_binary``
       or by copying a ready-made ``prism_file``,
    3. build the explicit model with stormpy and model-check the configured
       formula with a PRE_SAFETY shield expression, plus a fallback shield
       expression with threshold 1.0,
    4. translate the shield into a dict mapping symbolic states to
       allowed-action masks.

    All generated files go into a fresh per-instance directory under
    ``shielding_files/``.
    """

    ACTION_SPACE_SIZE = 4

    def __init__(
        self,
        grid_to_prism_binary,
        grid_file,
        prism_path,
        formula,
        env,
        prism_config=None,
        shield_value=0.9,
        shield_comparison='absolute',
        ignore_view: bool = False,
        nocleanup=False,
        prism_file=None,
    ) -> None:
        """Create the working directory and the shield expressions.

        Args:
            grid_to_prism_binary: Command run through the shell (``os.system``)
                to translate the exported grid into a PRISM program.
            grid_file: File name, inside the working directory, for the
                exported grid.
            prism_path: File name, inside the working directory, for the
                PRISM program.
            formula: Sequence ``[prefix, label_1, ..., label_k, suffix]``. The
                shield is computed for ``prefix + " | ".join(labels) + suffix``.
            env: Environment object; stored as ``self.env``.
            prism_config: Optional config file passed to the translator via ``-c``.
            shield_value: Threshold of the main shield expression.
            shield_comparison: ``"absolute"`` selects ``ShieldComparison.ABSOLUTE``;
                any other value selects ``ShieldComparison.RELATIVE``.
            ignore_view: Forwarded to ``get_allowed_actions_mask_for_labels``.
            nocleanup: Stored only. Cleanup in ``__del__`` is currently disabled.
            prism_file: If given, this PRISM file is copied instead of running
                the translator.
        """
        logger.success("MiniGridShieldHandler ctor called")
        os.makedirs("shielding_files", exist_ok=True)
        # NOTE(review): relies on the private ``tempfile._get_candidate_names``;
        # ``tempfile.mkdtemp(prefix=..., dir="shielding_files")`` is the public
        # (and race-free) equivalent.
        self.tmp_dir_name = f"shielding_files/{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
        os.makedirs(self.tmp_dir_name)
        self.grid_file = self.tmp_dir_name + "/" + grid_file
        self.grid_to_prism_binary = grid_to_prism_binary
        self.prism_path = self.tmp_dir_name + "/" + prism_path
        self.prism_config = prism_config
        self.prism_file = prism_file
        self.action_dictionary = None  # filled lazily by create_shield()
        self.ignore_view = ignore_view
        self.env = env
        # Buckets for dtControl CSV rows. Key order determines the key order
        # of the dict returned by get_dtcontrol_csv().
        self.dtcontrol_csvs = {
            "base": set(),
            "base_basic_preds": set(),
            "formulas_critical": [],
            "formulas_dangerous": [],
            "safe_unsafe": set(),
            "critical_base": [],
            "dangerous_base": [],
            "critical_state_action": [],
            "dangerous_state_action": [],
        }
        self.formula = formula
        comparison = (
            stormpy.logic.ShieldComparison.ABSOLUTE
            if shield_comparison == "absolute"
            else stormpy.logic.ShieldComparison.RELATIVE
        )
        self.shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, comparison, shield_value)
        self.fallback_shield_expression = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, comparison, 1.0)
        # The public attribute keeps the raw string (as before); the stormpy
        # enum is only needed for the two expressions above.
        self.shield_comparison = shield_comparison
        self.nocleanup = nocleanup

    def __del__(self):
        # Removal of the working directory is disabled, so ``nocleanup`` has
        # no effect at the moment. To re-enable:
        #     if not self.nocleanup:
        #         shutil.rmtree(self.tmp_dir_name)
        pass

    def __export_grid_to_text(self, env):
        """Write ``env.printGrid(init=True)`` to ``self.grid_file``."""
        with open(self.grid_file, "w") as f:
            f.write(env.printGrid(init=True))

    def __create_prism(self):
        """Produce the PRISM program at ``self.prism_path``.

        Copies ``self.prism_file`` when one was given. Otherwise runs
        ``grid_to_prism_binary`` on the exported grid, adding ``-c prism_config``
        when a config is set.

        Raises:
            AssertionError: if the translator exits with a non-zero status.
        """
        if self.prism_file is not None:
            logger.info(self.prism_file)
            logger.info(self.prism_path)
            shutil.copyfile(self.prism_file, self.prism_path)
            return
        # NOTE(review): the command is built as a shell string, so paths
        # containing spaces or shell metacharacters will break it.
        if self.prism_config is None:
            result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}")
            logger.info(self.grid_file)
            logger.info(self.prism_path)
        else:
            result = os.system(F"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path} -c {self.prism_config}")
        assert result == 0, "Prism file could not be generated"

    def __create_shield_dict(self):
        """Build the model, compute both shields and translate them into masks.

        Side effects:
            - Sets ``self.program``, ``self.model``, ``self.shield_scheduler``,
              ``self.fallback_shield``, ``self.state_valuations``,
              ``self.choice_labeling``, ``self.dangerous_states`` and
              ``self.action_dictionary``.
            - Exports both shields into the working directory.

        Returns:
            dict mapping each symbolic state to the mask returned by
            ``get_allowed_actions_mask_for_labels``.

        Raises:
            AssertionError: if model checking produced no shield.
        """
        self.program = stormpy.parse_prism_program(self.prism_path)
        prefix, labels, suffix = self.formula[0], self.formula[1:-1], self.formula[-1]
        f_comb = prefix + " | ".join(labels) + suffix
        # One property per label plus the combined one. All of them are handed
        # to the builder options below.
        f_sep = [prefix + label + suffix for label in labels]
        f_sep.append(f_comb)
        formulas_comb = stormpy.parse_properties_for_prism_program(f_comb, self.program)
        formulas_sep = stormpy.parse_properties_for_prism_program("; ".join(f_sep), self.program)
        options = stormpy.BuilderOptions([p.raw_formula for p in formulas_sep])
        options.set_build_state_valuations(True)
        options.set_build_choice_labels(True)
        options.set_build_all_labels()

        print("LOG: Starting with explicit model creation...")
        tic()
        model = stormpy.build_sparse_model_with_options(self.program, options)
        self.model = model
        toc()

        print("LOG: Starting with model checking...")
        tic()
        result = stormpy.model_checking(model, formulas_comb[0], extract_scheduler=True, shield_expression=self.shield_expression)
        fallback_result = stormpy.model_checking(model, formulas_comb[0], extract_scheduler=True, shield_expression=self.fallback_shield_expression)
        toc()
        assert result.has_shield
        shield = result.shield
        fallback_shield = fallback_result.shield
        self.fallback_shield = fallback_shield.construct()
        self.shield_scheduler = shield.construct()
        self.state_valuations = model.state_valuations
        self.choice_labeling = model.choice_labeling
        stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/shield")
        stormpy.shields.export_shield(model, fallback_shield, self.tmp_dir_name + "/fallback_shield")

        def labels_of(state_id, action):
            # Choice labels of the ``action``-th local choice of ``state_id``.
            return self.choice_labeling.get_labels_of_choice(model.get_choice_index(state_id, action))

        action_dictionary = dict()
        self.dangerous_states = list()
        print("LOG: Starting to translate shield...")
        tic()
        for state_id in model.states:
            state = id_to_symbolic(state_id, self.state_valuations)
            if not state:
                continue
            if 'deadlock' in state_id.labels or self.formula[1] in state_id.labels:
                continue
            choices = self.shield_scheduler.get_choice(state_id).choice_map
            action_dictionary[state] = get_allowed_actions_mask_for_labels(
                [labels_of(state_id, c[1]) for c in choices], self.ignore_view
            )
            if sum(action_dictionary[state]) != 0:
                continue

            # The main shield allows nothing here. Fall back to the actions
            # whose value in the fallback shield's choice map is minimal.
            fallback_choices = self.fallback_shield.get_choice(state_id).choice_map
            try:
                logger.warning(f"State: [colAgent={state.colAgent}, rowAgent={state.rowAgent}] is dangerous...")
                # min() raises ValueError on an empty choice map.
                fallback_value = min(fallback_choices, key=lambda c: c[0])[0]
                fallback_actions = [c[1] for c in fallback_choices if c[0] == fallback_value]
                # Look labels up by action index directly. Indexing a list built
                # in choice-map order with action indices is only correct when
                # the choice map lists every action in index order.
                enabled_fallback_actions_labels = [labels_of(state_id, a) for a in fallback_actions]
                action_dictionary[state] = get_allowed_actions_mask_for_labels(enabled_fallback_actions_labels, self.ignore_view)
                logger.warning(f"... only allowing {enabled_fallback_actions_labels}")
                self.dangerous_states.append(state)
            except (ValueError, IndexError) as e:
                # The state keeps the mask computed from the main shield and is
                # not recorded as dangerous.
                logger.error(f"{state=} is unsafe, {e=}")
        toc()
        self.action_dictionary = action_dictionary
        logger.success(f"Length of the action_dictionary of the shield:\t\t{len(action_dictionary)}")
        return action_dictionary

    def create_shield(self, **kwargs):
        """Return the shield action dictionary, computing it on the first call.

        Keyword Args:
            env: Environment whose grid is exported. Defaults to the ``env``
                given to the constructor when omitted.

        Returns:
            The cached ``self.action_dictionary`` if already computed, otherwise
            the freshly computed dictionary.
        """
        if self.action_dictionary is not None:
            return self.action_dictionary
        env = kwargs.get("env", self.env)
        self.__export_grid_to_text(env)
        logger.info("creating prism file")
        self.__create_prism()
        print("Computing new shield")
        return self.__create_shield_dict()

    def get_model(self):
        """Return the stormpy model built by ``create_shield``."""
        return self.model

    def get_dtcontrol_csv(self, action_comparison="absolute"):
        """Assemble dtControl CSV contents and column metadata.

        Requires ``self.translator``, which is not assigned anywhere in this
        class and must be set externally.

        Args:
            action_comparison: Forwarded to the translator's
                ``get_column_types_and_categories`` for the state-action metadata.

        Returns:
            Tuple ``(result, meta_data, meta_data_action, meta_data_basic)``.
            ``result`` maps each ``dtcontrol_csvs`` key to its rows prefixed by
            the matching ``#PERMISSIVE``/``BEGIN`` header.
        """
        pre = ["#PERMISSIVE", f"BEGIN {self.translator.number_state_variables()} 1"]
        pre_basic = ["#PERMISSIVE", f"BEGIN {self.translator.number_state_variables(only_basic_preds = True)} 1"]
        pre_action = ["#PERMISSIVE", f"BEGIN {self.translator.number_state_variables(action=True)} 1"]

        meta_data = dict()
        meta_data_action = dict()
        meta_data_basic = dict()
        types, values = self.translator.get_column_types_and_categories()
        meta_data["x_column_types"] = types
        meta_data["x_category_names"] = values
        types, values = self.translator.get_column_types_and_categories(action_comparison=action_comparison, action=True)
        meta_data_action["x_column_types"] = types
        meta_data_action["x_category_names"] = values
        types, values = self.translator.get_column_types_and_categories(only_basic_preds=True)
        meta_data_basic["x_column_types"] = types
        meta_data_basic["x_category_names"] = values

        result = dict()
        for key in self.dtcontrol_csvs:
            if key in ("formulas_dangerous", "formulas_critical"):
                # These buckets hold one CSV per formula, each needing its own header.
                result[key] = [pre + v for v in self.dtcontrol_csvs[key]]
            elif "state_action" in key:
                result[key] = pre_action + self.dtcontrol_csvs[key]
            elif "basic_preds" in key:
                result[key] = pre_basic + list(self.dtcontrol_csvs[key])
            else:
                result[key] = pre + list(self.dtcontrol_csvs[key])
        return result, meta_data, meta_data_action, meta_data_basic

    @property
    def feature_space(self, only_basic_preds=False) -> list[str]:
        """Feature names of the symbolic state space.

        Delegates to ``self.translator.get_feature_space`` when a translator is
        set. On ``AttributeError`` it derives the names from the first model
        state instead.

        ``only_basic_preds`` is kept for signature compatibility; property
        access cannot pass it, so the default ``False`` is always used.
        """
        try:
            return self.translator.get_feature_space(only_basic_preds=only_basic_preds)
        except AttributeError:
            self._feature_space = id_to_symbolic(self.model.states[0], self.state_valuations).feature_space
            return self._feature_space

    @property
    def num_features(self, only_basic_preds=False) -> int:
        """Number of features, cached after the first access.

        ``only_basic_preds`` is kept for signature compatibility only; property
        access always uses the default.
        """
        try:
            return self._num_features
        except AttributeError:
            # len() accepts exactly one argument; passing only_basic_preds as a
            # second argument raised TypeError on every first access.
            self._num_features = len(self.feature_space)
            return self._num_features

    @property
    def length_shield(self) -> int:
        """Number of states in the shield's action dictionary.

        ``create_shield`` must have run first; before that
        ``self.action_dictionary`` is ``None`` and ``len`` raises ``TypeError``.
        """
        return len(self.action_dictionary)