diff --git a/rom_evaluate.py b/rom_evaluate.py index 5b6907a..c1fc4d5 100644 --- a/rom_evaluate.py +++ b/rom_evaluate.py @@ -2,6 +2,8 @@ import sys import operator from os import listdir, system import re +from collections import defaultdict + from random import randrange from ale_py import ALEInterface, SDL_SUPPORT, Action #from colors import * @@ -190,39 +192,76 @@ def fillStateRanking(file_name, match=""): except: toc() +def createDisjunction(formulas): + return " | ".join(formulas) + +def clusterFormula(cluster): + formula = "" + states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster] + skiPositionGroup = defaultdict(list) + for item in states: + skiPositionGroup[item[2]].append(item) + + first = True + for skiPosition, group in skiPositionGroup.items(): + formula += f"ski_position={skiPosition} & (" + yPosGroup = defaultdict(list) + for item in group: + yPosGroup[item[1]].append(item) + for y, y_group in yPosGroup.items(): + if first: + first = False + else: + formula += " | " + sorted_y_group = sorted(y_group, key=lambda s: s[0]) + formula += f"( y={y} & (" + current_x_min = sorted_y_group[0][0] + current_x = sorted_y_group[0][0] + x_ranges = list() + for state in sorted_y_group[1:-1]: + if state[0] - current_x == 1: + current_x = state[0] + else: + x_ranges.append(f" ({current_x_min}<= x & x<={current_x})") + current_x_min = state[0] + current_x = state[0] + x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})") + formula += " | ".join(x_ranges) + formula += ") )" + formula += ")" + return formula + +def createUnsafeFormula(clusters): + formulas = "" + disjunction = "formula Unsafe = false" + for i, cluster in enumerate(clusters): + formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n" + clusterFormula(cluster) + disjunction += f" | Unsafe_{i}" + disjunction += ";\n" + return formulas + "\n" + disjunction + +def createSafeFormula(clusters): + formulas = "" + disjunction = "formula Safe = false" + for i, cluster in enumerate(clusters): + formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n" + disjunction += f" | Safe_{i}" + disjunction += ";\n" + return formulas + "\n" + disjunction + +def updatePrismFile(newFile, iteration, safeStates, unsafeStates): + logger.info("Creating next prism file") + tic() + initFile = f"{newFile}_no_formulas.prism" + newFile = f"{newFile}_{iteration:03}.prism" + exec(f"cp {initFile} {newFile}", verbose=False) + with open(newFile, "a") as prism: + prism.write(createSafeFormula(safeStates)) + prism.write(createUnsafeFormula(unsafeStates)) + + logger.info(f"Creating next prism file - DONE: took {toc()} seconds") -fixed_left_states = list() -fixed_right_states = list() -fixed_noop_states = list() - -def populate_fixed_actions(state, action): - if action == Action.LEFT: - fixed_left_states.append(state) - if action == Action.RIGHT: - fixed_right_states.append(state) - if action == Action.NOOP: - fixed_noop_states.append(state) - -def update_prism_file(old_prism_file, new_prism_file): - fixed_left_formula = "formula Fixed_Left = false " - fixed_right_formula = "formula Fixed_Right = false " - fixed_noop_formula = "formula Fixed_Noop = false " - for state in fixed_left_states: - fixed_left_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) " - for state in fixed_right_states: - fixed_right_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) " - for state in fixed_noop_states: - fixed_noop_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) " - fixed_left_formula += ";\n" - fixed_right_formula += ";\n" - fixed_noop_formula += ";\n" - with open(f'{old_prism_file}', 'r') as file : - filedata = file.read() - if len(fixed_left_states) > 0: filedata = re.sub(r"^formula Fixed_Left =.*$", fixed_left_formula, filedata, flags=re.MULTILINE) - if len(fixed_right_states) > 0: filedata = re.sub(r"^formula Fixed_Right =.*$", fixed_right_formula, filedata, flags=re.MULTILINE) - if len(fixed_noop_states) > 0: filedata = re.sub(r"^formula Fixed_Noop =.*$", fixed_noop_formula, filedata, flags=re.MULTILINE) - with open(f'{new_prism_file}', 'w') as file: - file.write(filedata) ale = ALEInterface() @@ -245,9 +284,9 @@ nn_wrapper = SampleFactoryNNQueryWrapper() iteration = 0 experiment_id = int(time.time()) init_mdp = "velocity_safety" -exec(f"mkdir -p images/testing_{experiment_id}") -exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png") -exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") +exec(f"mkdir -p images/testing_{experiment_id}", verbose=False) +#exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png") +#exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") markerSize = 1 #markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} @@ -265,9 +304,9 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1. exec(command, verbose=False) -def concatImages(prefix): - exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}.png", verbose=False) - exec(f"sxiv {imagesDir}/{prefix}.png&") +def concatImages(prefix, iteration): + exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False) + #exec(f"sxiv {imagesDir}/{prefix}.png&", verbose=False) def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): """ @@ -286,16 +325,16 @@ def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False) logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") -def drawClusters(clusterDict, target, alpha_factor=1.0): +def drawClusters(clusterDict, target, iteration, alpha_factor=1.0): for ski_position in range(1, num_ski_positions + 1): source = "images/1_full_scaled_down.png" exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") for _, clusterStates in clusterDict.items(): color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}" drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor) - concatImages(target) + concatImages(target, iteration) -def drawResult(clusterDict, target): +def drawResult(clusterDict, target, iteration): for ski_position in range(1, num_ski_positions + 1): source = "images/1_full_scaled_down.png" exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") @@ -306,7 +345,7 @@ def drawResult(clusterDict, target): elif result == Verdict.BAD: color = "200,0,0" drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7) - concatImages(target) + concatImages(target, iteration) def _init_logger(): logger = logging.getLogger('main') @@ -316,7 +355,7 @@ def _init_logger(): handler.setFormatter(formatter) logger.addHandler(handler) -def clusterImportantStates(ranking, n_clusters=40): +def clusterImportantStates(ranking, iteration, n_clusters=40): logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster") tic() states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking] @@ -325,20 +364,25 @@ def clusterImportantStates(ranking, n_clusters=40): clusterDict = {i : list() for i in range(0,n_clusters)} for i, state in enumerate(ranking): clusterDict[kmeans.labels_[i]].append(state) - drawClusters(clusterDict, f"clusters") + drawClusters(clusterDict, f"clusters", iteration) return clusterDict if __name__ == '__main__': _init_logger() logger = logging.getLogger('main') logger.info("Starting") - n_clusters = 40 + n_clusters = 2 testAll = False + + safeStates = list() + unsafeStates = list() + iteration = 0 while True: + updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) #computeStateRanking(f"{init_mdp}_{iteration:03}.prism") ranking = fillStateRanking("action_ranking") sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) - clusters = clusterImportantStates(sorted_ranking, n_clusters) + clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters) if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)} clusterResult = dict() @@ -361,10 +405,13 @@ if __name__ == '__main__': else: clusterResult[id] = (cluster, Verdict.BAD) verdictGood = False + unsafeStates.append(cluster) break if verdictGood: clusterResult[id] = (cluster, Verdict.GOOD) + safeStates.append(cluster) if testAll: drawClusters(failingPerCluster, f"failing") - drawResult(clusterResult, "result") + drawResult(clusterResult, "result", iteration) + iteration += 1 + - update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")