import sys
import operator
from os import listdir, system
import subprocess
import re
from collections import defaultdict

from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue
from dataclasses import dataclass, field
from sklearn.cluster import KMeans, DBSCAN
from enum import Enum
from copy import deepcopy
import numpy as np

import logging
logger = logging.getLogger(__name__)

#import readchar

from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper

import time

tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"


def tic():
    import time
    global startTime_for_tictoc
    startTime_for_tictoc = time.time()


def toc():
    import time
    if 'startTime_for_tictoc' in globals():
        return time.time() - startTime_for_tictoc


class Verdict(Enum):
    INCONCLUSIVE = 1
    GOOD = 2
    BAD = 3


verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}


def convert(tuples):
    return dict(tuples)


@dataclass(frozen=True)
class State:
    x: int
    y: int
    ski_position: int
    #velocity: int


def default_value():
    return {'action': None, 'choiceValue': None}


@dataclass(frozen=True)
class StateValue:
    ranking: float
    choices: dict = field(default_factory=default_value)


def exec(command, verbose=True):
    if verbose:
        print(f"Executing {command}")
    system(f"echo {command} >> list_of_exec")
    return system(command)


num_tests_per_cluster = 50
factor_tests_per_cluster = 0.2

num_ski_positions = 8


def input_to_action(char):
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    if char == "3":
        return "reset"
    if char == "4":
        return "set_x"
    if char == "5":
        return "set_vel"
    if char in ["w", "a", "s", "d"]:
        return char


def saveObservations(observations, verdict, testDir):
    testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
    if len(observations) < 20:
        logger.warning(f"Potentially spurious test case for {testDir}")
        testDir = f"{testDir}_pot_spurious"
    exec(f"mkdir {testDir}", verbose=False)
    for i, obs in enumerate(observations):
        img = Image.fromarray(obs)
        img.save(f"{testDir}/{i:003}.png")


ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10),
                        5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40)}


#def run_single_test(ale, nn_wrapper, x, y, ski_position, velocity, duration=50):
def run_single_test(ale, nn_wrapper, x, y, ski_position, duration=50):
    #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
    testDir = f"{x}_{y}_{ski_position}"  #_{velocity}"
    for i, r in enumerate(ramDICT[y]):
        ale.setRAM(i, r)
    ski_position_setting = ski_position_counter[ski_position]
    for i in range(0, ski_position_setting[1]):
        ale.act(ski_position_setting[0])
    ale.setRAM(14, 0)
    ale.setRAM(25, x)
    ale.setRAM(14, 180)  # TODO

    all_obs = list()
    speed_list = list()
    resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
    for i in range(0, 4):
        all_obs.append(resized_obs)
    for i in range(0, duration - 4):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)
        if i % 4 == 0:
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            ale.act(input_to_action(str(action)))
        else:
            ale.act(input_to_action(str(action)))
        speed_list.append(ale.getRAM()[14])
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            saveObservations(all_obs, Verdict.BAD, testDir)
            return Verdict.BAD
    saveObservations(all_obs, Verdict.GOOD, testDir)
    return Verdict.GOOD

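
# State ranking: computeStateRanking invokes the tempest binary (storm) on the
# current PRISM MDP, and fillStateRanking parses the produced action ranking file
# into State -> StateValue entries, keeping only states with ranking value > 0.1.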
def computeStateRanking(mdp_file):
    logger.info("Computing state ranking")
    tic()
    try:
        command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop 'Rmax=? [C <= 1000]'"
        result = subprocess.run(command, shell=True, check=True)
        print(result)
    except Exception as e:
        print(e)
        sys.exit(-1)
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")


def fillStateRanking(file_name, match=""):
    logger.info("Parsing state ranking")
    tic()
    state_ranking = dict()
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
        for line in file_content:
            if "move=0" not in line:
                continue
            ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:", ""))
            if ranking_value <= 0.1:
                continue
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
            #print("stateMapping", stateMapping)
            choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
            choices = {key: float(value) for (key, value) in choices.items()}
            #print("choices", choices)
            #print("ranking_value", ranking_value)
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))  #, int(stateMapping["velocity"]))
            value = StateValue(ranking_value, choices)
            state_ranking[state] = value
        logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
        return state_ranking
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(1)
    except:
        toc()

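
# PRISM formula generation: clusterFormula describes a cluster as a predicate over
# ski_position, y and consecutive x ranges; createBalancedDisjunction folds the
# per-cluster formulas pairwise into a single formula; createSafeFormula and
# createUnsafeFormula emit the corresponding "Safe"/"Unsafe" labels, which
# updatePrismFile appends to a copy of the base model.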
def createDisjunction(formulas):
    return " | ".join(formulas)


def clusterFormula(cluster):
    formula = ""
    #states = [(s[0].x, s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]
    states = [(s[0].x, s[0].y, s[0].ski_position) for s in cluster]
    skiPositionGroup = defaultdict(list)
    for item in states:
        skiPositionGroup[item[2]].append(item)

    first = True
    #todo add velocity here
    for skiPosition, group in skiPositionGroup.items():
        formula += f"ski_position={skiPosition} & ("
        yPosGroup = defaultdict(list)
        for item in group:
            yPosGroup[item[1]].append(item)
        for y, y_group in yPosGroup.items():
            if first:
                first = False
            else:
                formula += " | "
            sorted_y_group = sorted(y_group, key=lambda s: s[0])
            formula += f"( y={y} & ("
            current_x_min = sorted_y_group[0][0]
            current_x = sorted_y_group[0][0]
            x_ranges = list()
            for state in sorted_y_group[1:-1]:
                if state[0] - current_x == 1:
                    current_x = state[0]
                else:
                    x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
                    current_x_min = state[0]
                    current_x = state[0]
            x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})")
            formula += " | ".join(x_ranges)
            formula += ") )"
        formula += ")"
    return formula


def createBalancedDisjunction(indices, name):
    #logger.info(f"Creating balanced disjunction for {len(indices)} ({indices}) formulas")
    if len(indices) == 0:
        return f"formula {name} = false;\n"
    else:
        # pairwise reduction: combine neighbouring formulas until a single one remains
        while len(indices) > 1:
            indices_tmp = [f"({indices[i]} | {indices[i+1]})" for i in range(0, len(indices) - 1, 2)]
            if len(indices) % 2 == 1:
                indices_tmp.append(indices[-1])
            indices = indices_tmp
        disjunction = f"formula {name} = " + " ".join(indices) + ";\n"
        return disjunction


def createUnsafeFormula(clusters):
    label = "label \"Unsafe\" = Unsafe;\n"
    formulas = ""
    indices = list()
    for i, cluster in enumerate(clusters):
        formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n"
        indices.append(f"Unsafe_{i}")
    return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe") + label


def createSafeFormula(clusters):
    label = "label \"Safe\" = Safe;\n"
    formulas = ""
    indices = list()
    for i, cluster in enumerate(clusters):
        formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n"
        indices.append(f"Safe_{i}")
    return formulas + "\n" + createBalancedDisjunction(indices, "Safe") + label


def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    logger.info("Creating next prism file")
    tic()
    initFile = f"{newFile}_no_formulas.prism"
    newFile = f"{newFile}_{iteration:03}.prism"
    exec(f"cp {initFile} {newFile}", verbose=False)
    with open(newFile, "a") as prism:
        prism.write(createSafeFormula(safeStates))
        prism.write(createUnsafeFormula(unsafeStates))
    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")


ale = ALEInterface()

#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)

# Load the ROM file
ale.loadROM(rom_file)

with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70

nn_wrapper = SampleFactoryNNQueryWrapper()

experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)

markerSize = 1
imagesDir = f"images/testing_{experiment_id}"

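
# Visualisation helpers: markers are drawn onto the per-ski-position background
# images with ImageMagick's convert, tiled into a single overview per iteration
# with montage, and opened with sxiv.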
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
    markerList = {ski_position: list() for ski_position in range(1, num_ski_positions + 1)}
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {imagesDir}/{target_prefix}_{pos:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_individual.png"
        exec(command, verbose=False)


def concatImages(prefix, iteration):
    exec(f"montage {imagesDir}/{prefix}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
    exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)


def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster
    """
    markerList = {1: list(), 2: list(), 3: list(), 4: list(), 5: list(), 6: list(), 7: list(), 8: list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")


def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
    for ski_position in range(1, num_ski_positions + 1):
        source = "images/1_full_scaled_down.png"
        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_individual.png", verbose=False)
    for _, clusterStates in clusterDict.items():
        color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
    concatImages(target, iteration)


def drawResult(clusterDict, target, iteration):
    for ski_position in range(1, num_ski_positions + 1):
        source = "images/1_full_scaled_down.png"
        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_individual.png")
    for _, (clusterStates, result) in clusterDict.items():
        color = "100,100,100"
        if result == Verdict.GOOD:
            color = "0,200,0"
        elif result == Verdict.BAD:
            color = "200,0,0"
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
    concatImages(target, iteration)


def _init_logger():
    logger = logging.getLogger('main')
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('[%(levelname)s] %(module)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)


def clusterImportantStates(ranking, iteration):
    logger.info(f"Starting to cluster {len(ranking)} states into clusters")
    tic()
    #states = [[s[0].x, s[0].y, s[0].ski_position * 10, s[0].velocity * 10, s[1].ranking] for s in ranking]
    states = [[s[0].x, s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
    #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
    dbscan = DBSCAN(eps=15).fit(states)
    labels = dbscan.labels_
    print(labels)
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} clusters")
    clusterDict = {i: list() for i in range(0, n_clusters)}
    for i, state in enumerate(ranking):
        if labels[i] == -1:
            continue
        clusterDict[labels[i]].append(state)
    drawClusters(clusterDict, f"clusters", iteration)
    return clusterDict

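
# Refinement loop: each iteration extends the PRISM file with the Safe/Unsafe
# formulas collected so far, recomputes the state ranking, clusters the important
# states with DBSCAN and replays a sample of states from each cluster in ALE;
# clusters containing a failing test are marked unsafe, all others safe.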
if __name__ == '__main__':
    _init_logger()
    logger = logging.getLogger('main')
    logger.info("Starting")
    n_clusters = 40
    testAll = False

    safeStates = list()
    unsafeStates = list()
    iteration = 0
    while True:
        updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
        computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
        ranking = fillStateRanking("action_ranking")
        sorted_ranking = sorted((x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
        clusters = clusterImportantStates(sorted_ranking, iteration)

        if testAll:
            failingPerCluster = {i: list() for i in range(0, n_clusters)}
        clusterResult = dict()
        for id, cluster in clusters.items():
            num_tests = int(factor_tests_per_cluster * len(cluster))
            logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
            randomStates = np.random.choice(len(cluster), num_tests, replace=False)
            randomStates = [cluster[i] for i in randomStates]

            verdictGood = True
            for state in randomStates:
                x = state[0].x
                y = state[0].y
                ski_pos = state[0].ski_position
                #velocity = state[0].velocity
                result = run_single_test(ale, nn_wrapper, x, y, ski_pos, duration=50)
                if result == Verdict.BAD:
                    if testAll:
                        failingPerCluster[id].append(state)
                    else:
                        clusterResult[id] = (cluster, Verdict.BAD)
                        verdictGood = False
                        unsafeStates.append(cluster)
                        break
            if verdictGood:
                clusterResult[id] = (cluster, Verdict.GOOD)
                safeStates.append(cluster)
        if testAll:
            drawClusters(failingPerCluster, f"failing", iteration)
        drawResult(clusterResult, "result", iteration)
        iteration += 1
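
# Note: the refinement loop above has no termination condition; per-iteration
# PRISM files are written to the working directory and the result montages end
# up in images/testing_<experiment_id>/.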