diff --git a/rom_evaluate.py b/rom_evaluate.py index 7ef6064..7fb187b 100644 --- a/rom_evaluate.py +++ b/rom_evaluate.py @@ -109,7 +109,8 @@ ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.L #def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50): def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50): #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="") - testDir = f"{x}_{y}_{ski_position}"#_{velocity}" + #testDir = f"{x}_{y}_{ski_position}_{velocity}" + testDir = f"{x}_{y}_{ski_position}" for i, r in enumerate(ramDICT[y]): ale.setRAM(i,r) ski_position_setting = ski_position_counter[ski_position] @@ -140,7 +141,7 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50): saveObservations(all_obs, Verdict.GOOD, testDir) return Verdict.GOOD -def computeStateRanking(mdp_file): +def computeStateRanking(mdp_file, iteration): logger.info("Computing state ranking") tic() try: @@ -150,10 +151,11 @@ def computeStateRanking(mdp_file): except Exception as e: print(e) sys.exit(-1) + exec(f"mv action_ranking action_ranking_{iteration:03}") logger.info(f"Computing state ranking - DONE: took {toc()} seconds") def fillStateRanking(file_name, match=""): - logger.info("Parsing state ranking") + logger.info(f"Parsing state ranking, {file_name}") tic() state_ranking = dict() try: @@ -170,7 +172,8 @@ def fillStateRanking(file_name, match=""): choices = {key:float(value) for (key,value) in choices.items()} #print("choices", choices) #print("ranking_value", ranking_value) - state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))#, int(stateMapping["velocity"])) + #state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])) + state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) value = StateValue(ranking_value, choices) state_ranking[state] = value logger.info(f"Parsing state ranking - DONE: took {toc()} seconds") @@ -299,8 +302,8 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1. markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)} for state in states: s = state[0] - marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '" - #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '" + #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '" + marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '" markerList[s.ski_position].append(marker) for pos, marker in markerList.items(): command = f"convert {imagesDir}/{target_prefix}_{pos:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_individual.png" @@ -366,7 +369,6 @@ def clusterImportantStates(ranking, iteration): #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states) dbscan = DBSCAN(eps=15).fit(states) labels = dbscan.labels_ - print(labels) n_clusters = len(set(labels)) - (1 if -1 in labels else 0) logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster") clusterDict = {i : list() for i in range(0,n_clusters)} @@ -388,8 +390,8 @@ if __name__ == '__main__': iteration = 0 while True: updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) - computeStateRanking(f"{init_mdp}_{iteration:03}.prism") - ranking = fillStateRanking("action_ranking") + computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) + ranking = fillStateRanking(f"action_ranking_{iteration:03}") sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) clusters = clusterImportantStates(sorted_ranking, iteration) @@ -397,6 +399,7 @@ if __name__ == '__main__': clusterResult = dict() for id, cluster in clusters.items(): num_tests = int(factor_tests_per_cluster * len(cluster)) + num_tests = 1 logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}") randomStates = np.random.choice(len(cluster), num_tests, replace=False) randomStates = [cluster[i] for i in randomStates] @@ -408,6 +411,8 @@ if __name__ == '__main__': ski_pos = state[0].ski_position #velocity = state[0].velocity result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50) + #result = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50) + result = Verdict.BAD # TODO REMOVE ME!!!!!!!!!!!!!! if result == Verdict.BAD: if testAll: failingPerCluster[id].append(state) @@ -419,7 +424,9 @@ if __name__ == '__main__': if verdictGood: clusterResult[id] = (cluster, Verdict.GOOD) safeStates.append(cluster) + logger.info(f"Iteration: {iteration:03} -\tSafe Results : {sum([len(c) for c in safeStates])} -\tUnsafe Results:{sum([len(c) for c in unsafeStates])}") if testAll: drawClusters(failingPerCluster, f"failing", iteration) + drawResult(clusterResult, "result", iteration) iteration += 1