3 changed files with 674 additions and 109 deletions

@@ -0,0 +1,605 @@
import sys
import operator
from copy import deepcopy
from os import listdir, system
import subprocess
import re
from collections import defaultdict

from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue
from dataclasses import dataclass, field

from sklearn.cluster import KMeans, DBSCAN

from enum import Enum

import numpy as np

import logging
logger = logging.getLogger(__name__)

#import readchar

from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper

import time

tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
def tic():
    import time
    global startTime_for_tictoc
    startTime_for_tictoc = time.time()

def toc():
    import time
    if 'startTime_for_tictoc' in globals():
        return time.time() - startTime_for_tictoc
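# Illustrative usage of the timing helpers above (mirrors how they are used throughout this file):
#   tic()
#   ... do work ...
#   logger.info(f"... took {toc()} seconds")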
class Verdict(Enum):
    INCONCLUSIVE = 1
    GOOD = 2
    BAD = 3

verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}

def convert(tuples):
    return dict(tuples)
@dataclass(frozen=True)
class State:
    x: int
    y: int
    ski_position: int
    velocity: int

def default_value():
    return {'action' : None, 'choiceValue' : None}

@dataclass(frozen=True)
class StateValue:
    ranking: float
    choices: dict = field(default_factory=default_value)
@dataclass(frozen=False)
class TestResult:
    init_check_pes_min: float
    init_check_pes_max: float
    init_check_pes_avg: float
    init_check_opt_min: float
    init_check_opt_max: float
    init_check_opt_avg: float
    safe_states: int
    unsafe_states: int
    safe_cluster: int
    unsafe_cluster: int
    good_verdicts: int
    bad_verdicts: int
    policy_queries: int

    def __str__(self):
        return f"""Test Result:
    init_check_pes_min: {self.init_check_pes_min}
    init_check_pes_max: {self.init_check_pes_max}
    init_check_pes_avg: {self.init_check_pes_avg}
    init_check_opt_min: {self.init_check_opt_min}
    init_check_opt_max: {self.init_check_opt_max}
    init_check_opt_avg: {self.init_check_opt_avg}
"""

    @staticmethod
    def csv_header(ws=" "):
        string =  f"pesmin{ws}pesmax{ws}pesavg{ws}"
        string += f"optmin{ws}optmax{ws}optavg{ws}"
        string += f"sState{ws}uState{ws}"
        string += f"sClust{ws}uClust{ws}"
        string += f"gVerd{ws}bVerd{ws}queries"
        return string

    def csv(self):
        ws = " "
        string =  f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}"
        string += f"{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}"
        ws = "\t"  # note: the remaining (integer) columns are tab-separated, while csv_header() uses a single space throughout
        string += f"{self.safe_states}{ws}{self.unsafe_states}{ws}"
        string += f"{self.safe_cluster}{ws}{self.unsafe_cluster}{ws}"
        string += f"{self.good_verdicts}{ws}{self.bad_verdicts}{ws}{self.policy_queries}"
        return string
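# Illustrative only: with the default separator, TestResult.csv_header() yields
#   "pesmin pesmax pesavg optmin optmax optavg sState uState sClust uClust gVerd bVerd queries"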
def exec(command, verbose=True):
    # NB: shadows the built-in exec(); every shell command is appended to list_of_exec before it is run.
    if verbose: print(f"Executing {command}")
    system(f"echo {command} >> list_of_exec")
    return system(command)

num_tests_per_cluster = 50
#factor_tests_per_cluster = 0.2
num_ski_positions = 8
num_velocities = 5
def input_to_action(char):
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    if char == "3":
        return "reset"
    if char == "4":
        return "set_x"
    if char == "5":
        return "set_vel"
    if char in ["w", "a", "s", "d"]:
        return char
def saveObservations(observations, verdict, testDir):
    testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
    if len(observations) < 20:
        logger.warning(f"Potentially spurious test case for {testDir}")
        testDir = f"{testDir}_pot_spurious"
    exec(f"mkdir {testDir}", verbose=False)
    for i, obs in enumerate(observations):
        img = Image.fromarray(obs)
        img.save(f"{testDir}/{i:003}.png")
# Action and number of repetitions used to steer the agent into each of the eight discrete ski positions.
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
def run_single_test(ale, nn_wrapper, x, y, ski_position, velocity, duration=50):
    #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
    testDir = f"{x}_{y}_{ski_position}_{velocity}"
    try:
        # Restore the RAM snapshot for this y position, then steer into the requested ski position.
        for i, r in enumerate(ramDICT[y]):
            ale.setRAM(i, r)
        ski_position_setting = ski_position_counter[ski_position]
        for i in range(0, ski_position_setting[1]):
            ale.act(ski_position_setting[0])
            ale.setRAM(14, 0)
            ale.setRAM(25, x)
        ale.setRAM(14, 180) # TODO
    except Exception as e:
        print(e)
        logger.warning(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
        return (Verdict.INCONCLUSIVE, 0)

    num_queries = 0
    all_obs = list()
    speed_list = list()
    resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
    for i in range(0, 4):
        all_obs.append(resized_obs)
    for i in range(0, duration - 4):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)
        if i % 4 == 0:
            # Query the policy every fourth frame on a stack of the last four observations ...
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            num_queries += 1
            ale.act(input_to_action(str(action)))
        else:
            # ... and repeat the most recently queried action on the frames in between.
            ale.act(input_to_action(str(action)))
        speed_list.append(ale.getRAM()[14])
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            # Speed has stayed at zero for several frames, presumably because the agent crashed into an obstacle.
            #saveObservations(all_obs, Verdict.BAD, testDir)
            return (Verdict.BAD, num_queries)
    #saveObservations(all_obs, Verdict.GOOD, testDir)
    return (Verdict.GOOD, num_queries)
def skiPositionFormulaList(name):
    formulas = list()
    for i in range(1, num_ski_positions+1):
        formulas.append(f"\"{name}_{i}\"")
    return createBalancedDisjunction(formulas)
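# Illustrative only: skiPositionFormulaList("Unsafe") with num_ski_positions = 8 yields the balanced disjunction
#   ((("Unsafe_1" | "Unsafe_2") | ("Unsafe_3" | "Unsafe_4")) | (("Unsafe_5" | "Unsafe_6") | ("Unsafe_7" | "Unsafe_8")))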
def computeStateRanking(mdp_file, iteration):
    logger.info("Computing state ranking")
    tic()
    # Six filter queries (min/max/avg of Pmin and of Pmax) for globally avoiding Hit_Tree/Hit_Gate and the
    # Unsafe_i labels, plus a bounded cumulative reward query.
    prop =  f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += 'Rmax=? [C <= 200]'
    results = list()
    num_states = 0  # initialized before the try block so the return below does not fail if the model checking call errors out
    try:
        command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
        output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
        for line in output:
            #print(line)
            if "States:" in line:
                num_states = int(line.split(" ")[-1])
            if "Result" in line and len(results) < 6:
                range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
                if range_value:
                    results.append(float(range_value.group(2)))
                    results.append(float(range_value.group(3)))
                else:
                    value = re.search(r"(.*:)(.*)", line)
                    results.append(float(value.group(2)))
        exec(f"mv action_ranking action_ranking_{iteration:03}")
    except subprocess.CalledProcessError as e:
        # todo die gracefully if ranking is uniform
        print(e.output)
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
    return TestResult(*tuple(results), 0, 0, 0, 0, 0, 0, 0), num_states
def fillStateRanking(file_name, match=""):
    logger.info(f"Parsing state ranking, {file_name}")
    tic()
    state_ranking = dict()
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
        for line in file_content:
            if "move=0" not in line: continue
            ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
            # Only keep states whose ranking value exceeds the 0.1 threshold.
            if ranking_value <= 0.1:
                continue
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
            choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
            choices = {key: float(value) for (key, value) in choices.items()}
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
            value = StateValue(ranking_value, choices)
            state_ranking[state] = value
        logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
        return state_ranking
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(-1)
    except:
        # Any other parsing error is silently swallowed and None is returned.
        toc()
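# The regexes above assume (inferred from the patterns, not verified against an actual export) that each
# relevant line of the Tempest/Storm action_ranking file looks roughly like:
#   ... x=70&y=60&ski_position=4&velocity=6&move=0 ... Value:0.8731 ... left_action:0.12 noop_action:0.33 right_action:0.87 ...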
def createDisjunction(formulas):
    return " | ".join(formulas)
def statesFormulaTrimmed(states, name):
    #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
    skiPositionGroup = defaultdict(list)
    for item in states:
        skiPositionGroup[item[2]].append(item)

    formulas = list()
    for skiPosition, skiPos_group in skiPositionGroup.items():
        formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & "
        #print(f"{name} ski_pos:{skiPosition}")
        velocityGroup = defaultdict(list)
        velocityFormulas = list()
        for item in skiPos_group:
            velocityGroup[item[3]].append(item)
        for velocity, velocity_group in velocityGroup.items():
            #print(f"\tvel:{velocity}")
            yPosGroup = defaultdict(list)
            yFormulas = list()
            for item in velocity_group:
                yPosGroup[item[1]].append(item)
            for y, y_group in yPosGroup.items():
                #print(f"\t\ty:{y}")
                # Collapse consecutive x coordinates into closed ranges.
                sorted_y_group = sorted(y_group, key=lambda s: s[0])
                current_x_min = sorted_y_group[0][0]
                current_x = sorted_y_group[0][0]
                x_ranges = list()
                for state in sorted_y_group[1:-1]:
                    if state[0] - current_x == 1:
                        current_x = state[0]
                    else:
                        x_ranges.append(f" ({current_x_min}<=x&x<={current_x})")
                        current_x_min = state[0]
                        current_x = state[0]
                # Close the last open range; the final state of the group is folded into it.
                x_ranges.append(f" ({current_x_min}<=x&x<={sorted_y_group[-1][0]})")
                yFormulas.append(f" (y={y} & {createBalancedDisjunction(x_ranges)})")
                #x_ranges.clear()

            #velocityFormulas.append(f"(velocity={velocity} & {createBalancedDisjunction(yFormulas)})")
            velocityFormulas.append(f"({createBalancedDisjunction(yFormulas)})")
            #yFormulas.clear()
        formula += createBalancedDisjunction(velocityFormulas) + ");"
        #velocityFormulas.clear()
        formulas.append(formula)
    for i in range(1, num_ski_positions+1):
        if i in skiPositionGroup:
            continue
        formulas.append(f"formula {name}_{i} = false;")
    return "\n".join(formulas) + "\n"
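# Illustrative only (hypothetical coordinates): for three adjacent states (x=100..102, y=42, ski_position=3),
# statesFormulaTrimmed(states, "Safe") produces a PRISM formula roughly of the shape
#   formula Safe_3 = ( ski_position=3 & ( (y=42 &  (100<=x&x<=102))));
# plus "formula Safe_i = false;" for every ski position i that has no states.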
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
def pairwise(iterable):
    "s -> (s0, s1), (s2, s3), (s4, s5), ..."
    a = iter(iterable)
    return zip(a, a)
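# Illustrative only: list(pairwise([1, 2, 3, 4])) == [(1, 2), (3, 4)]; a trailing element of an odd-length
# input is dropped by zip(), which createBalancedDisjunction below compensates for explicitly.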
def createBalancedDisjunction(formulas):
    if len(formulas) == 0:
        return "false"
    while len(formulas) > 1:
        formulas_tmp = [f"({f} | {g})" for f, g in pairwise(formulas)]
        if len(formulas) % 2 == 1:
            formulas_tmp.append(formulas[-1])
        formulas = formulas_tmp
    return " ".join(formulas)
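# Illustrative only: createBalancedDisjunction(["a", "b", "c"]) returns "((a | b) | c)"; pairing the operands
# keeps the nesting depth of the generated formula logarithmic in the number of operands.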
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    logger.info("Creating next prism file")
    tic()
    initFile = f"{newFile}_no_formulas.prism"
    newFile = f"{newFile}_{iteration:03}.prism"
    exec(f"cp {initFile} {newFile}", verbose=False)
    with open(newFile, "a") as prism:
        prism.write(statesFormulaTrimmed(safeStates, "Safe"))
        prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe"))
        for i in range(1, num_ski_positions+1):
            prism.write(f"label \"Safe_{i}\" = Safe_{i};\n")
            prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n")

    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
ale = ALEInterface()

#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)

# Load the ROM file
ale.loadROM(rom_file)

with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70

nn_wrapper = SampleFactoryNNQueryWrapper()

experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)

imagesDir = f"images/testing_{experiment_id}"
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0, markerSize=1, drawCircle=False):
    #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
    markerList = {(ski_position, velocity): list() for velocity in range(0, num_velocities) for ski_position in range(1, num_ski_positions + 1)}
    images = dict()
    mergedImages = dict()
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0, num_velocities):
            images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
        mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
    for state in states:
        s = state[0]
        marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)]
        markerList[(s.ski_position, s.velocity)].append(marker)
    for (pos, vel), marker in markerList.items():
        if len(marker) == 0: continue
        if drawCircle:
            for m in marker:
                images[(pos,vel)] = cv2.circle(images[(pos,vel)], m[2], 1, m[0], thickness=-1)
                mergedImages[pos] = cv2.circle(mergedImages[pos], m[2], 1, m[0], thickness=-1)
        else:
            for m in marker:
                images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED)
                mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
    for (ski_position, velocity), image in images.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
    for ski_position, image in mergedImages.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)
def concatImages(prefix, iteration):
    logger.info(f"Concatenating images")
    images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0, num_velocities) for pos in range(1, num_ski_positions+1)]
    mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1, num_ski_positions+1)]
    for vel in range(0, num_velocities):
        for pos in range(1, num_ski_positions + 1):
            command =  f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
            command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
            command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
            exec(command, verbose=False)
    exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}.png", verbose=False)
    exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}_merged.png", verbose=False)
    #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
    logger.info(f"Concatenating images - DONE")
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster
    TODO
    markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    """
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
    logger.info(f"Drawing {len(clusterDict)} clusters")
    tic()
    for _, clusterStates in clusterDict.items():
        color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
        color = (int(color[0]), int(color[1]), int(color[2]))
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
    concatImages(target, iteration)
    logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
def drawResult(clusterDict, target, iteration, drawnCluster=set()):
    # drawnCluster defaults to a shared mutable set, but it is only read here, never modified.
    logger.info(f"Drawing {len(clusterDict)} results")
    tic()
    for id, (clusterStates, result) in clusterDict.items():
        if id in drawnCluster: continue
        # opencv wants BGR
        color = (100,100,100)
        if result == Verdict.GOOD:
            color = (0,200,0)
        elif result == Verdict.BAD:
            color = (0,0,200)
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
    logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
def _init_logger():
    logger = logging.getLogger('main')
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('[%(levelname)s] %(module)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
def clusterImportantStates(ranking, iteration):
    logger.info(f"Starting to cluster {len(ranking)} states into clusters")
    tic()
    # Feature vector per state: x, y, scaled ski position, scaled velocity, and the ranking value.
    states = [[s[0].x, s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
    #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
    kmeans = KMeans(len(states) // 15, random_state=0, n_init="auto").fit(states)
    #dbscan = DBSCAN(eps=5).fit(states)
    #labels = dbscan.labels_
    labels = kmeans.labels_
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} clusters")
    clusterDict = {i: list() for i in range(0, n_clusters)}
    strayStates = list()
    for i, state in enumerate(ranking):
        if labels[i] == -1:
            # Label -1 only occurs with DBSCAN; each such stray state becomes its own singleton cluster.
            clusterDict[n_clusters + len(strayStates) + 1] = list()
            clusterDict[n_clusters + len(strayStates) + 1].append(state)
            strayStates.append(state)
            continue
        clusterDict[labels[i]].append(state)
    if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1")
    #drawClusters(clusterDict, f"clusters", iteration)
    return clusterDict
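# Illustrative only: with roughly 1500 ranked states, KMeans is asked for 1500 // 15 = 100 clusters.
# Scaling ski_position and velocity by 20 is presumably meant to keep them comparable in magnitude to the pixel coordinates.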
def run_experiment(factor_tests_per_cluster):
    logger.info("Starting")
    num_queries = 0

    source = "images/1_full_scaled_down.png"
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0, num_velocities):
            exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
            exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)

    goodVerdicts = 0
    badVerdicts = 0
    goodVerdictTestCases = list()
    badVerdictTestCases = list()
    safeClusters = 0
    unsafeClusters = 0
    safeStates = set()
    unsafeStates = set()
    iteration = 0
    results = list()

    eps = 0.1
    updatePrismFile(init_mdp, iteration, set(), set())
    #modelCheckingResult, numStates = TestResult(0,0,0,0,0,0,0,0,0,0,0,0,0), 10
    modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_000.prism", iteration)
    results.append(modelCheckingResult)
    ranking = fillStateRanking(f"action_ranking_000")

    sorted_ranking = sorted((x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
    try:
        clusters = clusterImportantStates(sorted_ranking, iteration)
    except Exception as e:
        print(e)
        sys.exit(-1)

    clusterResult = dict()
    logger.info(f"Running tests")
    tic()
    num_cluster_tested = 0
    iteration = 0
    drawnCluster = set()
    for id, cluster in clusters.items():
        num_tests = int(factor_tests_per_cluster * len(cluster))
        if num_tests == 0: num_tests = 1
        logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
        randomStates = np.random.choice(len(cluster), num_tests, replace=False)
        randomStates = [cluster[i] for i in randomStates]

        verdictGood = True
        for state in randomStates:
            x = state[0].x
            y = state[0].y
            ski_pos = state[0].ski_position
            velocity = state[0].velocity
            result, num_queries_this_test_case = run_single_test(ale, nn_wrapper, x, y, ski_pos, velocity, duration=50)
            num_queries += num_queries_this_test_case
            if result == Verdict.BAD:
                # A single failing test marks the whole cluster as unsafe.
                clusterResult[id] = (cluster, Verdict.BAD)
                verdictGood = False
                unsafeStates.update([(s[0].x, s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
                badVerdicts += 1
                badVerdictTestCases.append(state)
            elif result == Verdict.GOOD:
                goodVerdicts += 1
                goodVerdictTestCases.append(state)
        if verdictGood:
            clusterResult[id] = (cluster, Verdict.GOOD)
            safeClusters += 1
            safeStates.update([(s[0].x, s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
        else:
            unsafeClusters += 1
        results[-1].safe_states = len(safeStates)
        results[-1].unsafe_states = len(unsafeStates)
        results[-1].policy_queries = num_queries
        results[-1].safe_cluster = safeClusters
        results[-1].unsafe_cluster = unsafeClusters
        results[-1].good_verdicts = goodVerdicts
        results[-1].bad_verdicts = badVerdicts
        num_cluster_tested += 1
        if num_cluster_tested % max(len(clusters)//20, 1) == 0:  # max(..., 1) guards against a zero divisor for fewer than 20 clusters
            iteration += 1
            # max(..., 1) in the ratios below avoids a division by zero before the first unsafe/bad result
            logger.info(f"Tested Cluster: {num_cluster_tested:03}\tSafe Cluster States : {len(safeStates)}({safeClusters}/{len(clusters)})\tUnsafe Cluster States:{len(unsafeStates)}({unsafeClusters}/{len(clusters)})\tGood Test Cases:{goodVerdicts}\tFailing Test Cases:{badVerdicts}\t{len(safeStates)/max(len(unsafeStates), 1)} - {goodVerdicts/max(badVerdicts, 1)}")
            drawResult(clusterResult, "result", iteration, drawnCluster)
            drawOntoSkiPosImage(goodVerdictTestCases, (10,255,50), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
            drawOntoSkiPosImage(badVerdictTestCases, (0,0,0), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
            concatImages("result", iteration)
            drawnCluster.update(clusterResult.keys())
            #updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
            #modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
            results.append(deepcopy(modelCheckingResult))
            logger.info(f"Model Checking Result: {modelCheckingResult}")
            # Account for self-loop states after first iteration
            if iteration > 0:
                results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafe_states + 0.0*results[-2].safe_states)
                results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafe_states + 1.0*results[-2].safe_states)
            print(TestResult.csv_header())
            for result in results[:-1]:
                print(result.csv())

    with open(f"data_new_method_{factor_tests_per_cluster}", "w") as f:
        f.write(TestResult.csv_header() + "\n")
        for result in results[:-1]:
            f.write(result.csv() + "\n")
_init_logger()
logger = logging.getLogger('main')
if __name__ == '__main__':
    for factor_tests_per_cluster in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        run_experiment(factor_tests_per_cluster)