import sys
import operator
from os import listdir, system
import subprocess
import re
from collections import defaultdict
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue
from dataclasses import dataclass, field
from sklearn.cluster import KMeans, DBSCAN
from enum import Enum
from copy import deepcopy
import numpy as np

import logging
logger = logging.getLogger(__name__)

#import readchar

from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper

import time

tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"


def tic():
    """Start the global wall-clock timer (pairs with toc())."""
    global startTime_for_tictoc
    startTime_for_tictoc = time.time()


def toc():
    """Return seconds elapsed since the last tic(), or None if tic() was never called."""
    if 'startTime_for_tictoc' in globals():
        return time.time() - startTime_for_tictoc


class Verdict(Enum):
    """Outcome of a single simulated test run."""
    INCONCLUSIVE = 1
    GOOD = 2
    BAD = 3


# RGB strings used when painting verdicts onto images (imagemagick-style colors).
verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}


def convert(tuples):
    """Turn an iterable of (key, value) tuples into a dict."""
    return dict(tuples)


@dataclass(frozen=True)
class State:
    """One abstract MDP state: screen position, ski orientation bucket and velocity bucket."""
    x: int
    y: int
    ski_position: int
    velocity: int


def default_value():
    """Default choices mapping for a StateValue with no recorded action values."""
    return {'action': None, 'choiceValue': None}


@dataclass(frozen=True)
class StateValue:
    """Ranking value of a state plus the per-action choice values parsed from the ranking file."""
    ranking: float
    choices: dict = field(default_factory=default_value)


@dataclass(frozen=False)
class TestResult:
    """Aggregated model-checking estimates plus bookkeeping for one refinement iteration."""
    init_check_pes_min: float
    init_check_pes_max: float
    init_check_pes_avg: float
    init_check_opt_min: float
    init_check_opt_max: float
    init_check_opt_avg: float
    # Updated in-place by the main loop after each testing round.
    safe_states: int
    unsafe_states: int
    policy_queries: int

    def __str__(self):
        return f"""Test Result:
init_check_pes_min: {self.init_check_pes_min}
init_check_pes_max: {self.init_check_pes_max}
init_check_pes_avg: {self.init_check_pes_avg}
init_check_opt_min: {self.init_check_opt_min}
init_check_opt_max: {self.init_check_opt_max}
init_check_opt_avg: {self.init_check_opt_avg}
"""

    def csv(self, ws=" "):
        """Serialize all fields as a ws-separated record.

        Bugfix: previously read non-existent attributes `safeStates`/`unsafeStates`
        instead of the declared fields.
        """
        return (f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}"
                f"{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}"
                f"{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}"
                f"{self.safe_states}{ws}{self.unsafe_states}{ws}{self.policy_queries}")


def exec(command, verbose=True):
    """Run a shell command via os.system, logging it to the `list_of_exec` file.

    NOTE(review): shadows the builtin `exec`; kept for backward compatibility
    with existing call sites.
    """
    if verbose:
        print(f"Executing {command}")
    system(f"echo {command} >> list_of_exec")
    return system(command)


num_tests_per_cluster = 50
factor_tests_per_cluster = 0.2
num_ski_positions = 8
num_velocities = 5


def input_to_action(char):
    """Map a single-character command (policy action index or manual key) to an ALE action."""
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    if char == "3":
        return "reset"
    if char == "4":
        return "set_x"
    if char == "5":
        return "set_vel"
    if char in ["w", "a", "s", "d"]:
        return char


def saveObservations(observations, verdict, testDir):
    """Dump the observation frames of one test run as numbered PNGs under images/testing_<id>/."""
    testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
    if len(observations) < 20:
        # Very short runs are usually artifacts of a failed environment reset.
        logger.warning(f"Potentially spurious test case for {testDir}")
        testDir = f"{testDir}_pot_spurious"
    exec(f"mkdir {testDir}", verbose=False)
    for i, obs in enumerate(observations):
        img = Image.fromarray(obs)
        img.save(f"{testDir}/{i:003}.png")


# For each ski_position bucket: the action to repeat and how many frames to repeat it
# in order to steer the skier into that orientation from the reset pose.
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10),
                        5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40)}


def run_single_test(ale, nn_wrapper, x, y, ski_position, velocity, duration=50):
    """Place the emulator into the given abstract state and roll out the policy.

    Returns (verdict, number_of_policy_queries). BAD means the skier got stuck
    (speed stayed 0 for several consecutive frames).
    """
    #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
    testDir = f"{x}_{y}_{ski_position}_{velocity}"
    try:
        # Restore the full RAM snapshot recorded for this y position.
        for i, r in enumerate(ramDICT[y]):
            ale.setRAM(i, r)
        # Steer into the requested ski orientation by repeating the mapped action.
        ski_position_setting = ski_position_counter[ski_position]
        for i in range(0, ski_position_setting[1]):
            ale.act(ski_position_setting[0])
        ale.setRAM(14, 0)
        ale.setRAM(25, x)
        ale.setRAM(14, 180)
        # TODO
    except Exception as e:
        print(e)
        logger.warning(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
        return (Verdict.INCONCLUSIVE, 0)
    num_queries = 0
    all_obs = list()
    speed_list = list()
    resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
    # Seed the 4-frame observation stack with the initial frame.
    for i in range(0, 4):
        all_obs.append(resized_obs)
    for i in range(0, duration - 4):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)
        # Query the policy every 4th frame (frame-skip), otherwise repeat the last action.
        if i % 4 == 0:
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            num_queries += 1
            ale.act(input_to_action(str(action)))
        else:
            ale.act(input_to_action(str(action)))
        speed_list.append(ale.getRAM()[14])
        # Stuck detection: zero speed over a recent window => crash/failure.
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            #saveObservations(all_obs, Verdict.BAD, testDir)
            return (Verdict.BAD, num_queries)
    #saveObservations(all_obs, Verdict.GOOD, testDir)
    return (Verdict.GOOD, num_queries)


def computeStateRanking(mdp_file, iteration):
    """Model-check the PRISM file with tempest/Storm and parse min/max/avg safety estimates.

    Returns a TestResult whose six estimate fields come from the first six
    `Result` lines of the Storm output; the bookkeeping fields start at 0.
    """
    logger.info("Computing state ranking")
    tic()
    prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
    prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
    prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
    prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
    prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
    prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
    prop += 'Rmax=? [C <= 200]'
    results = list()
    try:
        command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
        output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
        for line in output:
            #print(line)
            # Only the first six Result lines feed the TestResult estimates.
            if "Result" in line and not len(results) >= 6:
                range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
                if range_value:
                    results.append(float(range_value.group(2)))
                    results.append(float(range_value.group(3)))
                else:
                    value = re.search(r"(.*:)(.*)", line)
                    results.append(float(value.group(2)))
        # Keep the ranking file of this iteration for later parsing.
        exec(f"mv action_ranking action_ranking_{iteration:03}")
    except subprocess.CalledProcessError as e:
        # todo die gracefully if ranking is uniform
        print(e.output)
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
    return TestResult(*tuple(results), 0, 0, 0)


def fillStateRanking(file_name, match=""):
    """Parse a tempest action-ranking file into {State: StateValue}.

    Only `move=0` lines with ranking value > 0.1 are kept. Exits the program
    if the ranking file does not exist.
    """
    logger.info(f"Parsing state ranking, {file_name}")
    tic()
    state_ranking = dict()
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
        for line in file_content:
            if not "move=0" in line:
                continue
            ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:", ""))
            if ranking_value <= 0.1:
                continue
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
            #print("stateMapping", stateMapping)
            choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
            choices = {key: float(value) for (key, value) in choices.items()}
            #print("choices", choices)
            #print("ranking_value", ranking_value)
            # Velocity is halved to coarsen the abstraction.
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"]) // 2)
            #state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
            value = StateValue(ranking_value, choices)
            state_ranking[state] = value
        logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
        return state_ranking
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(-1)
    except Exception:
        # Bugfix: was a bare `except:` that silently swallowed everything.
        logger.exception(f"Failed to parse ranking file {file_name}")
        toc()


def createDisjunction(formulas):
    """Join formulas with a flat (left-deep) disjunction."""
    return " | ".join(formulas)


def statesFormulaTrimmed(states):
    """Build a PRISM formula characterizing the given set of state tuples.

    States are 4-tuples (x, y, ski_position, velocity). They are grouped by
    ski_position, then velocity, then y; consecutive x values are compressed
    into range constraints.
    """
    if len(states) == 0:
        return "false"
    #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
    skiPositionGroup = defaultdict(list)
    for item in states:
        skiPositionGroup[item[2]].append(item)
    formulas = list()
    for skiPosition, skiPos_group in skiPositionGroup.items():
        formula = f"( ski_position={skiPosition} & "
        firstVelocity = True
        velocityGroup = defaultdict(list)
        for item in skiPos_group:
            velocityGroup[item[3]].append(item)
        for velocity, velocity_group in velocityGroup.items():
            if firstVelocity:
                firstVelocity = False
            else:
                formula += " | "
            formula += f" (velocity={velocity} & "
            firstY = True
            yPosGroup = defaultdict(list)
            for item in velocity_group:
                yPosGroup[item[1]].append(item)
            for y, y_group in yPosGroup.items():
                if firstY:
                    firstY = False
                else:
                    formula += " | "
                sorted_y_group = sorted(y_group, key=lambda s: s[0])
                formula += f"( y={y} & ("
                current_x_min = sorted_y_group[0][0]
                current_x = sorted_y_group[0][0]
                x_ranges = list()
                # Compress consecutive x values into closed ranges.
                # Bugfix: iterate over all remaining states (was [1:-1], which
                # could merge a non-contiguous final x into the last range).
                for state in sorted_y_group[1:]:
                    if state[0] - current_x == 1:
                        current_x = state[0]
                    else:
                        x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
                        current_x_min = state[0]
                        current_x = state[0]
                x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
                formula += " | ".join(x_ranges)
                formula += ") )"
            formula += ")"
        formula += ")"
        formulas.append(formula)
    # Bugfix: removed leftover debug `print(formulas); sys.exit(1)` that
    # terminated the program before the formula could be returned.
    return createBalancedDisjunction(formulas)


# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
def pairwise(iterable):
    "s -> (s0, s1), (s2, s3), (s4, s5), ..."
    a = iter(iterable)
    return zip(a, a)


def createBalancedDisjunction(formulas):
    """Join formulas into a balanced binary disjunction tree (keeps parser recursion shallow)."""
    if len(formulas) == 0:
        return "false"
    while len(formulas) > 1:
        formulas_tmp = [f"({f} | {g})" for f, g in pairwise(formulas)]
        if len(formulas) % 2 == 1:
            formulas_tmp.append(formulas[-1])
        formulas = formulas_tmp
    return " ".join(formulas)


def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    """Copy the formula-free PRISM template and append Safe/Unsafe formulas and labels."""
    logger.info("Creating next prism file")
    tic()
    initFile = f"{newFile}_no_formulas.prism"
    newFile = f"{newFile}_{iteration:03}.prism"
    exec(f"cp {initFile} {newFile}", verbose=False)
    with open(newFile, "a") as prism:
        prism.write(f"formula Safe = {statesFormulaTrimmed(safeStates)};\n")
        prism.write(f"formula Unsafe = {statesFormulaTrimmed(unsafeStates)};\n")
        prism.write(f"label \"Safe\" = Safe;\n")
        prism.write(f"label \"Unsafe\" = Unsafe;\n")
    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")


ale = ALEInterface()

#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)

# Load the ROM file
ale.loadROM(rom_file)

# Per-y-position RAM snapshots used to teleport the emulator into a state.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70

nn_wrapper = SampleFactoryNNQueryWrapper()

experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)

markerSize = 1
imagesDir = f"images/testing_{experiment_id}"


def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
    """Draw rectangles for the given (State, StateValue) pairs onto the per-(position, velocity)
    images and the per-position merged images, then write them back to disk."""
    #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
    markerList = {(ski_position, velocity): list() for velocity in range(0, num_velocities) for ski_position in range(1, num_ski_positions + 1)}
    images = dict()
    mergedImages = dict()
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0, num_velocities):
            images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
        mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
    for state in states:
        s = state[0]
        #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
        marker = [color, alpha_factor * state[1].ranking, (s.x - markerSize, s.y - markerSize), (s.x + markerSize, s.y + markerSize)]
        markerList[(s.ski_position, s.velocity)].append(marker)
    for (pos, vel), marker in markerList.items():
        #command = f"convert {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png"
        #exec(command, verbose=False)
        if len(marker) == 0:
            continue
        for m in marker:
            images[(pos, vel)] = cv2.rectangle(images[(pos, vel)], m[2], m[3], m[0], cv2.FILLED)
            mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
    for (ski_position, velocity), image in images.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
    for ski_position, image in mergedImages.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)


def concatImages(prefix, iteration):
    """Annotate each individual image with its (position, velocity) label and montage
    all of them (and the merged per-position images) into one overview image per iteration."""
    logger.info(f"Concatenating images")
    images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0, num_velocities) for pos in range(1, num_ski_positions + 1)]
    mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1, num_ski_positions + 1)]
    for vel in range(0, num_velocities):
        for pos in range(1, num_ski_positions + 1):
            command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
            command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
            command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
            exec(command, verbose=False)
    exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
    exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}_merged.png", verbose=False)
    #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
    logger.info(f"Concatenating images - DONE")


def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster
    markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    """


def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
    """Draw every cluster in a random color onto the ski-position images and montage them."""
    logger.info(f"Drawing {len(clusterDict)} clusters")
    tic()
    #for velocity in range(0, num_velocities):
    #    for ski_position in range(1, num_ski_positions + 1):
    #        source = "images/1_full_scaled_down.png"
    #        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
    for _, clusterStates in clusterDict.items():
        color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
        color = (int(color[0]), int(color[1]), int(color[2]))
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
    concatImages(target, iteration)
    logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")


def drawResult(clusterDict, target, iteration):
    """Draw each tested cluster in green (GOOD), red (BAD) or grey (other) and montage them."""
    logger.info(f"Drawing {len(clusterDict)} results")
    #for velocity in range(0,num_velocities):
    #    for ski_position in range(1, num_ski_positions + 1):
    #        source = "images/1_full_scaled_down.png"
    #        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
    for _, (clusterStates, result) in clusterDict.items():
        # opencv wants BGR
        color = (100, 100, 100)
        if result == Verdict.GOOD:
            color = (0, 200, 0)
        elif result == Verdict.BAD:
            color = (0, 0, 200)
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
    concatImages(target, iteration)
    logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")


def _init_logger():
    """Configure the 'main' logger to print INFO-level records to stdout."""
    logger = logging.getLogger('main')
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('[%(levelname)s] %(module)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)


def clusterImportantStates(ranking, iteration):
    """DBSCAN-cluster the ranked states; every noise point (label -1) becomes its own cluster.

    Returns {cluster_id: [(State, StateValue), ...]} and draws the clusters.
    """
    logger.info(f"Starting to cluster {len(ranking)} states into clusters")
    tic()
    # Scale the categorical dimensions so DBSCAN's euclidean eps separates them cleanly.
    states = [[s[0].x, s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
    #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
    #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
    dbscan = DBSCAN(eps=5).fit(states)
    labels = dbscan.labels_
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
    clusterDict = {i: list() for i in range(0, n_clusters)}
    strayStates = list()
    for i, state in enumerate(ranking):
        if labels[i] == -1:
            # Noise point: give it a fresh singleton cluster past the DBSCAN ids.
            clusterDict[n_clusters + len(strayStates) + 1] = list()
            clusterDict[n_clusters + len(strayStates) + 1].append(state)
            strayStates.append(state)
            continue
        clusterDict[labels[i]].append(state)
    if len(strayStates) > 0:
        logger.warning(f"{len(strayStates)} stray states with label -1")
    drawClusters(clusterDict, f"clusters", iteration)
    return clusterDict


if __name__ == '__main__':
    _init_logger()
    logger = logging.getLogger('main')
    logger.info("Starting")
    testAll = False
    num_queries = 0

    # Seed the per-(position, velocity) images from the scaled-down slope image.
    source = "images/1_full_scaled_down.png"
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0, num_velocities):
            exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
            exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)

    safeStates = set()
    unsafeStates = set()
    iteration = 0
    results = list()
    eps = 0.1
    while True:
        updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
        modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
        if len(results) > 0:
            # Carry the bookkeeping of the previous iteration forward.
            # Bugfix: use the declared dataclass fields instead of ad-hoc
            # attributes (safeStates/unsafeStates/num_queries) that csv() never read.
            modelCheckingResult.safe_states = results[-1].safe_states
            modelCheckingResult.unsafe_states = results[-1].unsafe_states
            modelCheckingResult.policy_queries = results[-1].policy_queries
        results.append(modelCheckingResult)
        logger.info(f"Model Checking Result: {modelCheckingResult}")
        if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
            logger.info(f"Absolute difference between average estimates is below eps = {eps}... finishing!")
            break

        ranking = fillStateRanking(f"action_ranking_{iteration:03}")
        sorted_ranking = sorted((x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
        try:
            clusters = clusterImportantStates(sorted_ranking, iteration)
        except Exception as e:
            print(e)
            break

        if testAll:
            # Bugfix: was `range(0, n_clusters)` with n_clusters undefined here.
            failingPerCluster = {i: list() for i in range(0, len(clusters))}
        clusterResult = dict()
        logger.info(f"Running tests")
        tic()
        for cluster_id, cluster in clusters.items():
            num_tests = int(factor_tests_per_cluster * len(cluster))
            #logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {cluster_id}")
            randomStates = np.random.choice(len(cluster), num_tests, replace=False)
            randomStates = [cluster[i] for i in randomStates]

            verdictGood = True
            for state in randomStates:
                x = state[0].x
                y = state[0].y
                ski_pos = state[0].ski_position
                velocity = state[0].velocity
                result, num_queries_this_test_case = run_single_test(ale, nn_wrapper, x, y, ski_pos, velocity, duration=50)
                num_queries += num_queries_this_test_case
                if result == Verdict.BAD:
                    if testAll:
                        failingPerCluster[cluster_id].append(state)
                    else:
                        clusterResult[cluster_id] = (cluster, Verdict.BAD)
                        verdictGood = False
                        unsafeStates.update([(s[0].x, s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
                        break
            if verdictGood:
                clusterResult[cluster_id] = (cluster, Verdict.GOOD)
                safeStates.update([(s[0].x, s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
        logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}")
        results[-1].safe_states = len(safeStates)
        results[-1].unsafe_states = len(unsafeStates)
        results[-1].policy_queries = num_queries

        if testAll:
            drawClusters(failingPerCluster, f"failing", iteration)
        drawResult(clusterResult, "result", iteration)
        iteration += 1

    for result in results:
        print(result.csv())