diff --git a/rom_evaluate.py b/rom_evaluate.py index 0dec1ba..3d71a3a 100644 --- a/rom_evaluate.py +++ b/rom_evaluate.py @@ -203,6 +203,20 @@ def clusterFormula(cluster): return "(" + formulas[0] + ")" +def clusterFormulaXY(cluster): + if len(cluster) == 0: return + formulas = set() + for state in cluster: + formulas.add(f"(x={state[0].x} & y={state[0].y})") + formulas = list(formulas) + while len(formulas) > 1: + formulas_tmp = [f"({formulas[i]} | {formulas[i+1]})" for i in range(0,len(formulas)//2)] + if len(formulas) % 2 == 1: + formulas_tmp.append(formulas[-1]) + formulas = formulas_tmp + + return "(" + formulas[0] + ")" + def clusterFormulaTrimmed(cluster): formula = "" states = [(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster] @@ -271,19 +285,19 @@ def createUnsafeFormula(clusters): formulas = "" indices = list() for i, cluster in enumerate(clusters): - formulas += f"formula Unsafe_{i} = {clusterFormulaTrimmed(cluster)};\n" + formulas += f"formula Unsafe_{i} = {clusterFormulaXY(cluster)};\n" indices.append(f"Unsafe_{i}") - return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe")# + label + return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe") + label def createSafeFormula(clusters): label = "label \"Safe\" = Safe;\n" formulas = "" indices = list() for i, cluster in enumerate(clusters): - formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n" + formulas += f"formula Safe_{i} = {clusterFormulaXY(cluster)};\n" indices.append(f"Safe_{i}") - return formulas + "\n" + createBalancedDisjunction(indices, "Safe")# + label + return formulas + "\n" + createBalancedDisjunction(indices, "Safe") + label def updatePrismFile(newFile, iteration, safeStates, unsafeStates): logger.info("Creating next prism file") @@ -317,7 +331,7 @@ x = 70 nn_wrapper = SampleFactoryNNQueryWrapper() experiment_id = int(time.time()) -init_mdp = "velocity_safety" +init_mdp = "safety" exec(f"mkdir -p images/testing_{experiment_id}", verbose=False) markerSize = 1