diff --git a/rom_evaluate.py b/rom_evaluate.py index 3cf2978..657438a 100644 --- a/rom_evaluate.py +++ b/rom_evaluate.py @@ -170,22 +170,32 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50): #saveObservations(all_obs, Verdict.GOOD, testDir) return (Verdict.GOOD, num_queries) +def skiPositionFormulaList(name): + formulas = list() + for i in range(1, num_ski_positions+1): + formulas.append(f"\"{name}_{i}\"") + return createBalancedDisjunction(formulas) + + def computeStateRanking(mdp_file, iteration): logger.info("Computing state ranking") tic() - prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" - prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" - prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" - prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" - prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" - prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" + prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" + prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" + prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" + prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" + prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" + prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" prop += 'Rmax=? [C <= 200]' results = list() try: command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'" output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') + num_states = 0 for line in output: #print(line) + if "States:" in line: + num_states = int(line.split(" ")[-1]) if "Result" in line and not len(results) >= 6: range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) if range_value: @@ -199,7 +209,7 @@ def computeStateRanking(mdp_file, iteration): # todo die gracefully if ranking is uniform print(e.output) logger.info(f"Computing state ranking - DONE: took {toc()} seconds") - return TestResult(*tuple(results),0,0,0) + return TestResult(*tuple(results),0,0,0), num_states def fillStateRanking(file_name, match=""): logger.info(f"Parsing state ranking, {file_name}") @@ -214,13 +224,9 @@ def fillStateRanking(file_name, match=""): if ranking_value <= 0.1: continue stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) - #print("stateMapping", stateMapping) choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) choices = {key:float(value) for (key,value) in choices.items()} - #print("choices", choices) - #print("ranking_value", ranking_value) state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2) - #state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) value = StateValue(ranking_value, choices) state_ranking[state] = value logger.info(f"Parsing state ranking - DONE: took {toc()} seconds") @@ -235,8 +241,7 @@ def fillStateRanking(file_name, match=""): def createDisjunction(formulas): return " | ".join(formulas) -def statesFormulaTrimmed(states): - if len(states) == 0: return "false" +def statesFormulaTrimmed(states, name): #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster] skiPositionGroup = defaultdict(list) for item in states: @@ -244,7 +249,7 @@ def statesFormulaTrimmed(states): formulas = list() for skiPosition, skiPos_group in skiPositionGroup.items(): - formula = f"( ski_position={skiPosition} & " + formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & " firstVelocity = True velocityGroup = defaultdict(list) for item in skiPos_group: @@ -254,18 +259,16 @@ def statesFormulaTrimmed(states): firstVelocity = False else: formula += " | " + formulasPerSkiPosition = list() formula += f" (velocity={velocity} & " firstY = True yPosGroup = defaultdict(list) + yAndXRanges = dict() for item in velocity_group: yPosGroup[item[1]].append(item) for y, y_group in yPosGroup.items(): - if firstY: - firstY = False - else: - formula += " | " sorted_y_group = sorted(y_group, key=lambda s: s[0]) - formula += f"( y={y} & (" + #formula += f"( y={y} & (" current_x_min = sorted_y_group[0][0] current_x = sorted_y_group[0][0] x_ranges = list() @@ -273,18 +276,47 @@ def statesFormulaTrimmed(states): if state[0] - current_x == 1: current_x = state[0] else: - x_ranges.append(f" ({current_x_min}<= x & x<={current_x})") + x_ranges.append(f" ({current_x_min}<=x&x<={current_x})") current_x_min = state[0] current_x = state[0] - x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})") - formula += " | ".join(x_ranges) - formula += ") )" + x_ranges.append(f" ({current_x_min}<=x&x<={sorted_y_group[-1][0]})") + xRangesDisjunction = createBalancedDisjunction(x_ranges) + if xRangesDisjunction in yAndXRanges: + yAndXRanges[xRangesDisjunction].append(y) + else: + yAndXRanges[xRangesDisjunction] = list() + yAndXRanges[xRangesDisjunction].append(y) + for xRange, ys in yAndXRanges.items(): + #if firstY: + # firstY = False + #else: + # formula += " | " + sorted_ys = sorted(ys) + if len(ys) == 1: + formulasPerSkiPosition.append(f" ({xRange} & y={ys[0]})") + continue + current_y_min = sorted_ys[0] + current_y = sorted_ys[0] + y_ranges = list() + for y in sorted_ys[1:]: + if y - current_y == 2: + current_y = y + elif abs(y - current_y) > 2: + y_ranges.append(f" ({current_y_min}<=y&y<={current_y})") + current_y_min = y + current_y = y + y_ranges.append(f" ({current_y_min}<=y&y<={sorted_ys[-1]})") + formulasPerSkiPosition.append(f" ({xRange} & ({createBalancedDisjunction(y_ranges)}))") + + formula += createBalancedDisjunction(formulasPerSkiPosition) formula += ")" - formula += ")" + formula += ");" formulas.append(formula) - print(formulas) - sys.exit(1) - return createBalancedDisjunction(formulas) + for i in range(1, num_ski_positions+1): + if i in skiPositionGroup: + continue + formulas.append(f"formula {name}_{i} = false;") + return "\n".join(formulas) + "\n" # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list def pairwise(iterable): @@ -309,10 +341,11 @@ def updatePrismFile(newFile, iteration, safeStates, unsafeStates): newFile = f"{newFile}_{iteration:03}.prism" exec(f"cp {initFile} {newFile}", verbose=False) with open(newFile, "a") as prism: - prism.write(f"formula Safe = {statesFormulaTrimmed(safeStates)};\n") - prism.write(f"formula Unsafe = {statesFormulaTrimmed(unsafeStates)};\n") - prism.write(f"label \"Safe\" = Safe;\n") - prism.write(f"label \"Unsafe\" = Unsafe;\n") + prism.write(statesFormulaTrimmed(safeStates, "Safe")) + prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe")) + for i in range(1,num_ski_positions+1): + prism.write(f"label \"Safe_{i}\" = Safe_{i};\n") + prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n") logger.info(f"Creating next prism file - DONE: took {toc()} seconds") @@ -486,11 +519,11 @@ if __name__ == '__main__': eps = 0.1 while True: updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) - modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) + modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) if len(results) > 0: - modelCheckingResult.safeStates = results[-1].safeStates - modelCheckingResult.unsafeStates = results[-1].unsafeStates - modelCheckingResult.num_queries = results[-1].num_queries + modelCheckingResult.safeStates = results[-1].safeStates + modelCheckingResult.unsafeStates = results[-1].unsafeStates + modelCheckingResult.policy_queries = results[-1].policy_queries results.append(modelCheckingResult) logger.info(f"Model Checking Result: {modelCheckingResult}") if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps: @@ -538,7 +571,13 @@ if __name__ == '__main__': logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}") results[-1].safeStates = len(safeStates) results[-1].unsafeStates = len(unsafeStates) - results[-1].num_queries = num_queries + results[-1].policy_queries = num_queries + # Account for self-loop states after first iteration + if iteration > 0: + results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafeStates + 0.0*results[-2].safeStates) + results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafeStates + 1.0*results[-2].safeStates) + for result in results: + print(result.csv()) if testAll: drawClusters(failingPerCluster, f"failing", iteration) drawResult(clusterResult, "result", iteration)