You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
232 lines
11 KiB
232 lines
11 KiB
from shield_handler import ShieldHandler
|
|
import tempfile, datetime, shutil
|
|
import os
|
|
import stormpy
|
|
import stormpy.core
|
|
import stormpy.simulator
|
|
from itertools import chain, combinations
|
|
import re
|
|
from minigrid.core.state import to_state, State
|
|
|
|
from utils import tic, toc, get_allowed_actions_mask_for_labels, id_to_symbolic, StateClassification
|
|
|
|
from loguru import logger
|
|
|
|
class MiniGridShieldHandler(ShieldHandler):
    """Compute a pre-safety shield for a MiniGrid environment via PRISM/stormpy.

    The handler writes the textual grid of an environment to disk, obtains a
    PRISM model for it (either by running ``grid_to_prism_binary`` or by
    copying a pre-built ``prism_file``), builds and model-checks that model
    with stormpy, and converts the resulting shield into a dictionary that
    maps each symbolic state to the value returned by
    ``get_allowed_actions_mask_for_labels``.

    All generated files live in a fresh directory below ``shielding_files/``.

    NOTE(review): ``get_dtcontrol_csv`` and ``feature_space`` read
    ``self.translator``, which is never assigned in this class -- presumably
    it is provided by ``ShieldHandler`` or set externally; confirm.
    """

    # Not referenced anywhere inside this class.
    ACTION_SPACE_SIZE = 4

    def __init__(
        self,
        grid_to_prism_binary,
        grid_file,
        prism_path,
        formula,
        env,
        prism_config=None,
        shield_value=0.9,
        shield_comparison='absolute',
        ignore_view: bool = False,
        nocleanup=False,
        prism_file=None,
    ) -> None:
        """Create the working directory and store the shield configuration.

        Args:
            grid_to_prism_binary: Command used to translate the grid file into
                a PRISM file (invoked with ``-i``, ``-o`` and optionally ``-c``).
            grid_file: File name of the grid export, relative to the working
                directory created here.
            prism_path: File name of the PRISM model, relative to the working
                directory created here.
            formula: Indexable sequence; element 0 is used as the property
                prefix, the last element as the suffix, and the elements in
                between are joined with ``" | "`` to form the property.
                Element 1 is additionally compared against state labels.
            env: Environment kept as the default for ``create_shield``.
            prism_config: Optional configuration passed to the binary via ``-c``.
            shield_value: Threshold passed to the stormpy ``ShieldExpression``.
            shield_comparison: ``"absolute"`` selects
                ``ShieldComparison.ABSOLUTE``; any other value selects
                ``ShieldComparison.RELATIVE``.
            ignore_view: Forwarded to ``get_allowed_actions_mask_for_labels``.
            nocleanup: Stored only; cleanup is currently disabled (see
                ``__del__``).
            prism_file: Optional pre-built PRISM file that is copied to
                ``prism_path`` instead of running the binary.
        """
        logger.success("MiniGridShieldHandler ctor called")

        # Every handler gets its own timestamped scratch directory.
        # NOTE(review): relies on the private ``tempfile._get_candidate_names``.
        os.makedirs("shielding_files", exist_ok=True)
        self.tmp_dir_name = f"shielding_files/{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{next(tempfile._get_candidate_names())}"
        os.makedirs(self.tmp_dir_name)

        self.grid_file = self.tmp_dir_name + "/" + grid_file
        self.grid_to_prism_binary = grid_to_prism_binary
        self.prism_path = self.tmp_dir_name + "/" + prism_path
        self.prism_config = prism_config
        self.prism_file = prism_file
        # Filled by create_shield(); None means "not computed yet".
        self.action_dictionary = None
        self.ignore_view = ignore_view
        self.env = env

        # Containers read by get_dtcontrol_csv(); nothing in this class fills them.
        self.dtcontrol_csvs = {
            "base": set(),
            "base_basic_preds": set(),
            "formulas_critical": [],
            "formulas_dangerous": [],
            "safe_unsafe": set(),
            "critical_base": [],
            "dangerous_base": [],
            "critical_state_action": [],
            "dangerous_state_action": [],
        }

        self.formula = formula

        # The stormpy enum is only needed to build the two expressions; the
        # attribute itself keeps the caller's string.
        if shield_comparison == "absolute":
            comparison = stormpy.logic.ShieldComparison.ABSOLUTE
        else:
            comparison = stormpy.logic.ShieldComparison.RELATIVE
        self.shield_expression = stormpy.logic.ShieldExpression(
            stormpy.logic.ShieldingType.PRE_SAFETY, comparison, shield_value
        )
        # Same shield type with a fixed value of 1.0; used for states in which
        # the primary shield allows no action at all.
        self.fallback_shield_expression = stormpy.logic.ShieldExpression(
            stormpy.logic.ShieldingType.PRE_SAFETY, comparison, 1.0
        )
        self.shield_comparison = shield_comparison
        self.nocleanup = nocleanup

    def __del__(self):
        """Do nothing; the working directory is intentionally left on disk."""
        # Removing ``self.tmp_dir_name`` (unless ``nocleanup`` is set) is
        # deliberately disabled, so generated files remain available.
        pass

    def __export_grid_to_text(self, env):
        """Write ``env.printGrid(init=True)`` to ``self.grid_file``."""
        # Render first so a failing printGrid does not leave a truncated file.
        grid_text = env.printGrid(init=True)
        with open(self.grid_file, "w") as f:
            f.write(grid_text)

    def __create_prism(self):
        """Make sure a PRISM model exists at ``self.prism_path``.

        A given ``prism_file`` is copied verbatim; otherwise
        ``grid_to_prism_binary`` is run on the exported grid file.

        Raises:
            AssertionError: If the binary exits with a non-zero status.
        """
        if self.prism_file is not None:
            logger.info(self.prism_file)
            logger.info(self.prism_path)
            shutil.copyfile(self.prism_file, self.prism_path)
            return

        logger.info(self.grid_file)
        logger.info(self.prism_path)

        command = f"{self.grid_to_prism_binary} -i {self.grid_file} -o {self.prism_path}"
        if self.prism_config is not None:
            command += f" -c {self.prism_config}"
        result = os.system(command)

        # Raised explicitly (instead of ``assert``) so the check also runs
        # under ``python -O``; the exception type is unchanged.
        if result != 0:
            raise AssertionError("Prism file could not be generated")

    def __labels_for_choices(self, state_id, choices):
        """Return the choice labels for the given ``choice_map`` entries.

        Each entry is indexed as ``entry[1]`` to obtain the choice index local
        to ``state_id``, which is resolved through ``self.model`` and
        ``self.choice_labeling``.
        """
        return [
            self.choice_labeling.get_labels_of_choice(
                self.model.get_choice_index(state_id, entry[1])
            )
            for entry in choices
        ]

    def __create_shield_dict(self):
        """Build the model, compute the shields and translate them to a dict.

        Side effects: sets ``program``, ``model``, ``shield_scheduler``,
        ``fallback_shield``, ``state_valuations``, ``choice_labeling``,
        ``dangerous_states`` and ``action_dictionary``, and exports both
        shields into the working directory.

        Returns:
            dict: Symbolic state -> result of
            ``get_allowed_actions_mask_for_labels`` for that state.

        Raises:
            AssertionError: If model checking produced no shield.
        """
        self.program = stormpy.parse_prism_program(self.prism_path)

        prefix, suffix = self.formula[0], self.formula[-1]
        label_parts = self.formula[1:-1]
        f_comb = prefix + " | ".join(label_parts) + suffix
        # One property per part plus the combined one; all of them are handed
        # to the builder options.
        f_sep = [prefix + part + suffix for part in label_parts]
        f_sep.append(f_comb)

        formulas_comb = stormpy.parse_properties_for_prism_program(f_comb, self.program)
        formulas_sep = stormpy.parse_properties_for_prism_program("; ".join(f_sep), self.program)

        options = stormpy.BuilderOptions([p.raw_formula for p in formulas_sep])
        options.set_build_state_valuations(True)
        options.set_build_choice_labels(True)
        options.set_build_all_labels()

        print("LOG: Starting with explicit model creation...")
        tic()
        model = stormpy.build_sparse_model_with_options(self.program, options)
        self.model = model
        toc()

        print("LOG: Starting with model checking...")
        tic()
        result = stormpy.model_checking(
            model, formulas_comb[0], extract_scheduler=True,
            shield_expression=self.shield_expression,
        )
        fallback_result = stormpy.model_checking(
            model, formulas_comb[0], extract_scheduler=True,
            shield_expression=self.fallback_shield_expression,
        )
        toc()

        # Explicit raise so the check survives ``python -O``.
        if not result.has_shield:
            raise AssertionError("Model checking did not produce a shield")

        shield = result.shield
        fallback_shield = fallback_result.shield
        self.fallback_shield = fallback_shield.construct()
        self.shield_scheduler = shield.construct()
        self.state_valuations = model.state_valuations
        self.choice_labeling = model.choice_labeling

        stormpy.shields.export_shield(model, shield, self.tmp_dir_name + "/shield")
        stormpy.shields.export_shield(model, fallback_shield, self.tmp_dir_name + "/fallback_shield")

        self.dangerous_states = []
        action_dictionary = {}

        print("LOG: Starting to translate shield...")
        tic()
        for state_id in model.states:
            state = id_to_symbolic(state_id, self.state_valuations)
            if not state:
                continue
            # Deadlock states and states carrying the first formula label get
            # no entry in the dictionary.
            if 'deadlock' in state_id.labels or self.formula[1] in state_id.labels:
                continue

            choices = self.shield_scheduler.get_choice(state_id).choice_map
            action_dictionary[state] = get_allowed_actions_mask_for_labels(
                self.__labels_for_choices(state_id, choices), self.ignore_view
            )

            if sum(action_dictionary[state]) != 0:
                continue

            # The primary shield allows nothing here: fall back to the
            # choices with the smallest value in the fallback shield.
            choices = self.fallback_shield.get_choice(state_id).choice_map
            try:
                logger.warning(f"State: [colAgent={state.colAgent}, rowAgent={state.rowAgent}] is dangerous...")
                fallback_value = min(choices, key=lambda c: c[0])[0]
                fallback_choices = [c for c in choices if c[0] == fallback_value]
                # Resolve labels straight from each entry's own choice index;
                # indexing a per-position label list with that index is only
                # correct if the choice map lists every choice in order.
                enabled_fallback_actions_labels = self.__labels_for_choices(state_id, fallback_choices)
                action_dictionary[state] = get_allowed_actions_mask_for_labels(
                    enabled_fallback_actions_labels, self.ignore_view
                )
                logger.warning(f"... only allowing {enabled_fallback_actions_labels}")
                self.dangerous_states.append(state)
            except (ValueError, IndexError) as e:
                # ValueError: ``min`` received an empty choice map.
                logger.error(f"{state=} is unsafe, {e=}")
        toc()

        self.action_dictionary = action_dictionary
        logger.success(f"Length of the action_dictionary of the shield:\t\t{len(action_dictionary)}")
        return action_dictionary

    def create_shield(self, **kwargs):
        """Return the action dictionary, computing it on the first call.

        Args:
            **kwargs: ``env`` -- environment whose grid is exported; defaults
                to the environment passed to the constructor.

        Returns:
            dict: The cached or newly computed action dictionary.
        """
        if self.action_dictionary is not None:
            return self.action_dictionary

        env = kwargs.get("env", self.env)
        self.__export_grid_to_text(env)
        logger.info("creating prism file")
        self.__create_prism()
        print("Computing new shield")
        return self.__create_shield_dict()

    def get_model(self):
        """Return the stormpy model built by ``create_shield``.

        Raises ``AttributeError`` if no shield has been computed yet, because
        ``self.model`` is only assigned during shield construction.
        """
        return self.model

    def get_dtcontrol_csv(self, action_comparison="absolute"):
        """Prefix the collected dtcontrol rows with their header lines.

        Args:
            action_comparison: Forwarded to
                ``translator.get_column_types_and_categories`` for the
                state-action meta data.

        Returns:
            tuple: ``(result, meta_data, meta_data_action, meta_data_basic)``
            where ``result`` maps every key of ``self.dtcontrol_csvs`` to its
            rows preceded by the matching ``#PERMISSIVE`` / ``BEGIN`` header,
            and each meta-data dict has the keys ``x_column_types`` and
            ``x_category_names``.
        """
        translator = self.translator

        pre = ["#PERMISSIVE", f"BEGIN {translator.number_state_variables()} 1"]
        pre_basic = ["#PERMISSIVE", f"BEGIN {translator.number_state_variables(only_basic_preds = True)} 1"]
        pre_action = ["#PERMISSIVE", f"BEGIN {translator.number_state_variables(action=True)} 1"]

        types, values = translator.get_column_types_and_categories()
        meta_data = {"x_column_types": types, "x_category_names": values}

        types, values = translator.get_column_types_and_categories(action_comparison=action_comparison, action=True)
        meta_data_action = {"x_column_types": types, "x_category_names": values}

        types, values = translator.get_column_types_and_categories(only_basic_preds=True)
        meta_data_basic = {"x_column_types": types, "x_category_names": values}

        result = {}
        for key, rows in self.dtcontrol_csvs.items():
            if key in ("formulas_dangerous", "formulas_critical"):
                # These entries hold one row list per formula.
                result[key] = [pre + v for v in rows]
            elif "state_action" in key:
                result[key] = pre_action + list(rows)
            elif "basic_preds" in key:
                result[key] = pre_basic + list(rows)
            else:
                result[key] = pre + list(rows)
        return result, meta_data, meta_data_action, meta_data_basic

    @property
    def feature_space(self, only_basic_preds=False) -> list[str]:
        """Feature names of a symbolic state.

        Uses ``self.translator.get_feature_space`` when that succeeds. On
        ``AttributeError`` it falls back to the ``feature_space`` of the
        first model state. As this is a property, ``only_basic_preds``
        always keeps its default value.
        """
        try:
            return self.translator.get_feature_space(only_basic_preds=only_basic_preds)
        except AttributeError:
            first_state = id_to_symbolic(self.model.states[0], self.state_valuations)
            self._feature_space = first_state.feature_space
            return self._feature_space

    @property
    def num_features(self, only_basic_preds=False) -> int:
        """Number of entries in ``feature_space`` (cached after first access)."""
        try:
            return self._num_features
        except AttributeError:
            # ``len`` accepts exactly one argument; passing a second one made
            # this property raise TypeError on every first access.
            self._num_features = len(self.feature_space)
            return self._num_features

    @property
    def length_shield(self) -> int:
        """Number of states in the action dictionary; 0 before it is computed."""
        if self.action_dictionary is None:
            return 0
        return len(self.action_dictionary)