| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -2,6 +2,8 @@ import sys | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import operator | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from os import listdir, system | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import re | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from collections import defaultdict | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from random import randrange | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ale_py import ALEInterface, SDL_SUPPORT, Action | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					#from colors import * | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -190,39 +192,76 @@ def fillStateRanking(file_name, match=""): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    except: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        toc() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def createDisjunction(formulas): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return " | ".join(formulas) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def clusterFormula(cluster): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    formula = "" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    skiPositionGroup = defaultdict(list) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for item in states: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        skiPositionGroup[item[2]].append(item) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    first = True | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for skiPosition, group in skiPositionGroup.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        formula += f"ski_position={skiPosition} & (" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        yPosGroup = defaultdict(list) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for item in group: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            yPosGroup[item[1]].append(item) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        for y, y_group in yPosGroup.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if first: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                first = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                formula += " | " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            sorted_y_group = sorted(y_group, key=lambda s: s[0]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            formula += f"( y={y} & (" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_x_min = sorted_y_group[0][0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            current_x = sorted_y_group[0][0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            x_ranges = list() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            for state in sorted_y_group[1:-1]: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                if state[0] - current_x == 1: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    current_x = state[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    x_ranges.append(f" ({current_x_min}<= x & x<={current_x})") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    current_x_min = state[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    current_x = state[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            formula += " | ".join(x_ranges) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            formula += ") )" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        formula += ")" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        return formula | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def createUnsafeFormula(clusters): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    formulas = "" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    disjunction = "formula Unsafe = false" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i, cluster in enumerate(clusters): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        clusterFormula(cluster) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        disjunction += f" | Unsafe_{i}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    disjunction += ";\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return formulas + "\n" + disjunction | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def createSafeFormula(clusters): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    formulas = "" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    disjunction = "formula Safe = false" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i, cluster in enumerate(clusters): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        disjunction += f" | Safe_{i}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    disjunction += ";\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return formulas + "\n" + disjunction | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def updatePrismFile(newFile, iteration, safeStates, unsafeStates): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info("Creating next prism file") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    tic() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    initFile = f"{newFile}_no_formulas.prism" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    newFile = f"{newFile}_{iteration:03}.prism" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    exec(f"cp {initFile} {newFile}", verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    with open(newFile, "a") as prism: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        prism.write(createSafeFormula(safeStates)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        prism.write(createUnsafeFormula(unsafeStates)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"Creating next prism file - DONE: took {toc()} seconds") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					fixed_left_states = list() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					fixed_right_states = list() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					fixed_noop_states = list() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def populate_fixed_actions(state, action): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if action == Action.LEFT: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        fixed_left_states.append(state) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if action == Action.RIGHT: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        fixed_right_states.append(state) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if action == Action.NOOP: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        fixed_noop_states.append(state) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def update_prism_file(old_prism_file, new_prism_file): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    fixed_left_formula = "formula Fixed_Left = false " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    fixed_right_formula = "formula Fixed_Right = false " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    fixed_noop_formula = "formula Fixed_Noop = false " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for state in fixed_left_states: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        fixed_left_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for state in fixed_right_states: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        fixed_right_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for state in fixed_noop_states: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        fixed_noop_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) " | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    fixed_left_formula += ";\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    fixed_right_formula += ";\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    fixed_noop_formula += ";\n" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    with open(f'{old_prism_file}', 'r') as file : | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					      filedata = file.read() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if len(fixed_left_states)  > 0: filedata = re.sub(r"^formula Fixed_Left =.*$", fixed_left_formula, filedata, flags=re.MULTILINE) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if len(fixed_right_states) > 0: filedata = re.sub(r"^formula Fixed_Right =.*$", fixed_right_formula, filedata, flags=re.MULTILINE) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if len(fixed_noop_states)  > 0: filedata = re.sub(r"^formula Fixed_Noop =.*$", fixed_noop_formula, filedata, flags=re.MULTILINE) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    with open(f'{new_prism_file}', 'w') as file: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					      file.write(filedata) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					ale = ALEInterface() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -245,9 +284,9 @@ nn_wrapper = SampleFactoryNNQueryWrapper() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					iteration = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					experiment_id = int(time.time()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					init_mdp = "velocity_safety" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					exec(f"mkdir -p images/testing_{experiment_id}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					exec(f"mkdir -p images/testing_{experiment_id}", verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					#exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					#exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					markerSize = 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -265,9 +304,9 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1. | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        exec(command, verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def concatImages(prefix): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}.png", verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    exec(f"sxiv {imagesDir}/{prefix}.png&") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def concatImages(prefix, iteration): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    #exec(f"sxiv {imagesDir}/{prefix}.png&", verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    """ | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -286,16 +325,16 @@ def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def drawClusters(clusterDict, target, alpha_factor=1.0): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def drawClusters(clusterDict, target, iteration, alpha_factor=1.0): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for ski_position in range(1, num_ski_positions + 1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        source = "images/1_full_scaled_down.png" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for _, clusterStates in clusterDict.items(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    concatImages(target) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    concatImages(target, iteration) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def drawResult(clusterDict, target): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def drawResult(clusterDict, target, iteration): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for ski_position in range(1, num_ski_positions + 1): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        source = "images/1_full_scaled_down.png" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -306,7 +345,7 @@ def drawResult(clusterDict, target): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        elif result == Verdict.BAD: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            color = "200,0,0" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    concatImages(target) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    concatImages(target, iteration) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def _init_logger(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger = logging.getLogger('main') | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -316,7 +355,7 @@ def _init_logger(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    handler.setFormatter(formatter) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.addHandler(handler) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def clusterImportantStates(ranking, n_clusters=40): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def clusterImportantStates(ranking, iteration, n_clusters=40): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    tic() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking] | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -325,20 +364,25 @@ def clusterImportantStates(ranking, n_clusters=40): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    clusterDict = {i : list() for i in range(0,n_clusters)} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    for i, state in enumerate(ranking): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        clusterDict[kmeans.labels_[i]].append(state) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    drawClusters(clusterDict, f"clusters") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    drawClusters(clusterDict, f"clusters", iteration) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return clusterDict | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					if __name__ == '__main__': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    _init_logger() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger = logging.getLogger('main') | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    logger.info("Starting") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    n_clusters = 40 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    n_clusters = 2 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    testAll = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    safeStates = list() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    unsafeStates = list() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    iteration = 0 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    while True: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        #computeStateRanking(f"{init_mdp}_{iteration:03}.prism") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ranking = fillStateRanking("action_ranking") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        clusters = clusterImportantStates(sorted_ranking, n_clusters) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        clusterResult = dict() | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -361,10 +405,13 @@ if __name__ == '__main__': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                    else: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        clusterResult[id] = (cluster, Verdict.BAD) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        verdictGood = False | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        unsafeStates.append(cluster) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        break | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            if verdictGood: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                clusterResult[id] = (cluster, Verdict.GOOD) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                safeStates.append(cluster) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        if testAll: drawClusters(failingPerCluster, f"failing") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        drawResult(clusterResult, "result") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        drawResult(clusterResult, "result", iteration) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        iteration += 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism") |