You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
605 lines
26 KiB
605 lines
26 KiB
import sys
|
|
import operator
|
|
from copy import deepcopy
|
|
from os import listdir, system
|
|
import subprocess
|
|
import re
|
|
from collections import defaultdict
|
|
|
|
from random import randrange
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action
|
|
from PIL import Image
|
|
from matplotlib import pyplot as plt
|
|
import cv2
|
|
import pickle
|
|
import queue
|
|
from dataclasses import dataclass, field
|
|
|
|
from sklearn.cluster import KMeans, DBSCAN
|
|
|
|
from enum import Enum
|
|
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
|
|
import logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
#import readchar
|
|
|
|
from sample_factory.algo.utils.tensor_dict import TensorDict
|
|
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
|
|
|
|
import time
|
|
|
|
import os

# External tool and asset locations. The defaults are the original author's
# machine layout; set TEMPEST_BINARY / SKIING_ROM in the environment to run
# the script elsewhere without editing the source.
tempest_binary = os.environ.get(
    "TEMPEST_BINARY",
    "/home/spranger/projects/tempest-devel/ranking_release/bin/storm",
)
rom_file = os.environ.get(
    "SKIING_ROM",
    "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin",
)
|
|
|
|
def tic():
    """Start (or restart) the module-global stopwatch that :func:`toc` reads."""
    import time

    global startTime_for_tictoc
    # perf_counter is monotonic: unlike time.time(), measured durations are not
    # distorted by wall-clock adjustments (NTP sync, manual clock changes).
    startTime_for_tictoc = time.perf_counter()


def toc():
    """Return the seconds elapsed since the most recent :func:`tic` call.

    Returns:
        Elapsed time as a float, or None if :func:`tic` has never been called.
    """
    import time

    if 'startTime_for_tictoc' in globals():
        return time.perf_counter() - startTime_for_tictoc
    return None
|
|
|
|
class Verdict(Enum):
    """Outcome of one simulated test run, and of a whole cluster.

    run_single_test yields INCONCLUSIVE when the start state cannot be set up,
    BAD when the skier's speed stays at zero, and GOOD otherwise.
    """

    INCONCLUSIVE = 1
    GOOD = 2
    BAD = 3


# "R,G,B" color string for each verdict (not referenced elsewhere in this file).
verdict_to_color_map = dict(
    zip(
        (Verdict.BAD, Verdict.INCONCLUSIVE, Verdict.GOOD),
        ("200,0,0", "40,40,200", "00,200,100"),
    )
)
|
|
|
|
def convert(tuples):
    """Build a dict from ``(key, value)`` pairs; later duplicates overwrite earlier ones."""
    mapping = {}
    mapping.update(tuples)
    return mapping
|
|
|
|
@dataclass(frozen=True, eq=True)
class State:
    """Abstract Skiing state as read from the tempest ranking file.

    Frozen and therefore hashable, so instances serve as keys of the state
    ranking dictionary built by fillStateRanking.
    """

    x: int             # horizontal position; run_single_test writes it to RAM address 25
    y: int             # vertical position; selects the RAM snapshot ramDICT[y]
    ski_position: int  # 1..num_ski_positions
    velocity: int      # model velocity halved (integer division, see fillStateRanking)
|
|
def default_value():
    """Return a fresh placeholder mapping used as the default StateValue.choices."""
    placeholder = dict(action=None, choiceValue=None)
    return placeholder


@dataclass(frozen=True)
class StateValue:
    """Ranking information attached to a State.

    ``ranking`` is the state's value from the ranking file; ``choices`` maps
    each action name (left/right/noop) to its per-choice value.
    """

    ranking: float
    choices: dict = field(default_factory=default_value)
|
|
|
|
@dataclass(frozen=False)
class TestResult:
    """Model-checking bounds plus testing statistics for one reporting step.

    The six ``init_check_*`` values are the first six numbers parsed from the
    model checker output: min/max/avg of the Pmin queries ("pes") followed by
    min/max/avg of the Pmax queries ("opt"). The counters are filled in place
    by run_experiment while clusters are tested, hence the mutable dataclass.
    """

    init_check_pes_min: float
    init_check_pes_max: float
    init_check_pes_avg: float
    init_check_opt_min: float
    init_check_opt_max: float
    init_check_opt_avg: float
    safe_states: int     # states belonging to clusters judged GOOD
    unsafe_states: int   # states belonging to clusters judged BAD
    safe_cluster: int    # number of GOOD clusters
    unsafe_cluster: int  # number of BAD clusters
    good_verdicts: int   # individual test runs judged GOOD
    bad_verdicts: int    # individual test runs judged BAD
    policy_queries: int  # total number of policy queries issued

    def __str__(self):
        init_check_names = (
            "init_check_pes_min",
            "init_check_pes_max",
            "init_check_pes_avg",
            "init_check_opt_min",
            "init_check_opt_max",
            "init_check_opt_avg",
        )
        body = "".join(f"{name}: {getattr(self, name)}\n" for name in init_check_names)
        return "Test Result:\n" + body

    @staticmethod
    def csv_header(ws=" "):
        """Return the column header line matching :meth:`csv`, separated by *ws*."""
        columns = [
            "pesmin", "pesmax", "pesavg",
            "optmin", "optmax", "optavg",
            "sState", "uState",
            "sClust", "uClust",
            "gVerd", "bVerd", "queries",
        ]
        return str(ws).join(columns)

    def csv(self):
        """Return one data row: six 4-decimal floats (space separated), then the counters (tab separated)."""
        bounds = (
            self.init_check_pes_min, self.init_check_pes_max, self.init_check_pes_avg,
            self.init_check_opt_min, self.init_check_opt_max, self.init_check_opt_avg,
        )
        counters = (
            self.safe_states, self.unsafe_states,
            self.safe_cluster, self.unsafe_cluster,
            self.good_verdicts, self.bad_verdicts, self.policy_queries,
        )
        bound_part = " ".join(f"{value:0.04f}" for value in bounds)
        counter_part = "\t".join(f"{value}" for value in counters)
        return f"{bound_part} {counter_part}"
|
|
|
|
|
|
def exec(command, verbose=True):
    """Run *command* through the shell and append it to the ``list_of_exec`` log.

    Args:
        command: Shell command line to execute.
        verbose: If true, print the command before running it.

    Returns:
        The exit status reported by ``os.system``.
    """
    if verbose:
        print(f"Executing {command}")
    # Record the command verbatim with Python I/O. Logging it via
    # ``system(f"echo {command} >> list_of_exec")`` re-parsed the command in a
    # shell: quotes vanished from the log, and any ``;``, ``&``, ``|`` or ``>``
    # inside *command* was executed or redirected by the logging step itself.
    with open("list_of_exec", "a", encoding="utf-8") as log:
        log.write(f"{command}\n")
    return system(command)
|
|
|
|
# Tuning knob and state-space dimensions shared by the functions below:
#   num_tests_per_cluster - not read by the code in this file; the per-cluster
#                           test count is derived from factor_tests_per_cluster,
#                           which is now passed to run_experiment.
#   num_ski_positions     - ski_position takes the values 1..8
#   num_velocities        - velocity takes the values 0..4
num_tests_per_cluster, num_ski_positions, num_velocities = 50, 8, 5
|
|
|
|
def input_to_action(char):
    """Translate a one-character command into an ALE action or a control keyword.

    "0"/"1"/"2" map to Action.NOOP/RIGHT/LEFT, "3"/"4"/"5" to the keywords
    "reset"/"set_x"/"set_vel", and "w", "a", "s", "d" are passed through.
    Any other input yields None.
    """
    if char in ["w", "a", "s", "d"]:
        return char
    for key, keyword in (("3", "reset"), ("4", "set_x"), ("5", "set_vel")):
        if char == key:
            return keyword
    if char == "0":
        chosen = Action.NOOP
    elif char == "1":
        chosen = Action.RIGHT
    elif char == "2":
        chosen = Action.LEFT
    else:
        chosen = None
    return chosen
|
|
|
|
def saveObservations(observations, verdict, testDir):
    """Dump the frames of one test run as numbered PNG files.

    Files go to ``images/testing_<experiment_id>/<VERDICT>_<testDir>_<n>/``
    (n = number of frames). Runs with fewer than 20 frames are flagged with a
    warning and a ``_pot_spurious`` directory suffix.

    Args:
        observations: Sequence of frames accepted by ``PIL.Image.fromarray``.
        verdict: Verdict of the run; its ``name`` prefixes the directory.
        testDir: Identifier of the test case, e.g. "x_y_skiposition_velocity".
    """
    testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
    if len(observations) < 20:
        # Logger.warn is a long-deprecated alias; warning() is the supported API.
        logger.warning(f"Potentially spurious test case for {testDir}")
        testDir = f"{testDir}_pot_spurious"
    # -p: re-running the same test case must not fail on an existing directory.
    exec(f"mkdir -p {testDir}", verbose=False)
    for index, frame in enumerate(observations):
        Image.fromarray(frame).save(f"{testDir}/{index:003}.png")
|
|
|
|
# Ski position -> (action to repeat, number of repetitions). run_single_test
# applies the action this many times after restoring a RAM snapshot,
# presumably to turn the skis into the requested position.
ski_position_counter = {
    1: (Action.LEFT, 40),
    2: (Action.LEFT, 35),
    3: (Action.LEFT, 30),
    4: (Action.LEFT, 10),
    5: (Action.NOOP, 1),
    6: (Action.RIGHT, 10),
    7: (Action.RIGHT, 30),
    8: (Action.RIGHT, 40),
}
|
|
|
|
def run_single_test(ale, nn_wrapper, x, y, ski_position, velocity, duration=50):
    """Start the emulator in an abstract state and let the policy drive.

    The RAM snapshot ``ramDICT[y]`` is restored, the action from
    ``ski_position_counter[ski_position]`` is repeated to reach the ski
    position, and ``x`` is written to RAM address 25. The policy is then queried
    every 4 frames on the last four 84x84 grayscale frames, and its action is
    applied on every frame, for ``duration - 4`` frames.

    Args:
        ale: ALE interface with the Skiing ROM loaded.
        nn_wrapper: Policy wrapper exposing ``query(TensorDict)``.
        x, y: Start position; ``y`` must be a key of ``ramDICT``.
        ski_position: Key of ``ski_position_counter``.
        velocity: Only used in the test-case name and log messages; the speed
            byte (RAM 14) is set to a fixed value (see TODO).
        duration: Number of frames of the test run.

    Returns:
        ``(Verdict, number_of_policy_queries)``. BAD once more than 15 speed
        readings exist and the five readings before the latest are all zero;
        INCONCLUSIVE if the start state could not be set up; GOOD otherwise.
    """
    testDir = f"{x}_{y}_{ski_position}_{velocity}"  # for the (disabled) saveObservations calls
    try:
        for address, value in enumerate(ramDICT[y]):
            ale.setRAM(address, value)
        turn_action, repetitions = ski_position_counter[ski_position]
        for _ in range(repetitions):
            ale.act(turn_action)
        ale.setRAM(14, 0)
        ale.setRAM(25, x)
        ale.setRAM(14, 180)  # TODO  NOTE(review): `velocity` is not applied here
    except Exception as e:
        print(e)
        logger.warning(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
        return (Verdict.INCONCLUSIVE, 0)

    def grab_frame():
        return cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)

    num_queries = 0
    # Seed the 4-frame stack with the start frame.
    all_obs = [grab_frame()] * 4
    speed_list = list()
    for step in range(0, duration - 4):
        all_obs.append(grab_frame())
        # step 0 always queries, so `action` is bound before its first use.
        if step % 4 == 0:
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            num_queries += 1
        ale.act(input_to_action(str(action)))
        # Store a Python int: getRAM() yields NumPy uint8 bytes, and summing
        # uint8 scalars can wrap modulo 256 under NumPy >= 2 promotion rules,
        # which would make a moving skier look stopped.
        speed_list.append(int(ale.getRAM()[14]))
        if len(speed_list) > 15 and not any(speed_list[-6:-1]):
            #saveObservations(all_obs, Verdict.BAD, testDir)
            return (Verdict.BAD, num_queries)
    #saveObservations(all_obs, Verdict.GOOD, testDir)
    return (Verdict.GOOD, num_queries)
|
|
|
|
def skiPositionFormulaList(name):
    """Return a balanced disjunction over the quoted labels ``"<name>_1"`` .. ``"<name>_N"``.

    N is ``num_ski_positions``.
    """
    quoted_labels = [f'"{name}_{position}"' for position in range(1, num_ski_positions + 1)]
    return createBalancedDisjunction(quoted_labels)
|
|
|
|
|
|
# Output-line patterns, compiled once: an interval result "...: [lo, hi]..."
# or a scalar result "...: value".
_RANGE_RESULT = re.compile(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*")
_SCALAR_RESULT = re.compile(r"(.*:)(.*)")


def _ranking_property():
    """Build the tempest property string used by computeStateRanking.

    Six ``filter`` queries (min/max/avg over Pmin, then over Pmax) for never
    hitting a tree, a gate or an Unsafe state, followed by a reward query.
    """
    safe = skiPositionFormulaList('Safe')
    unsafe = skiPositionFormulaList('Unsafe')
    path_formula = f'G !("Hit_Tree" | "Hit_Gate" | {unsafe})'
    state_filter = f'(!"S_Hit_Tree" & !"S_Hit_Gate") | ({safe} | {unsafe})'
    queries = [
        f"filter({aggregate}, {bound}=? [ {path_formula} ], {state_filter} );"
        for bound in ("Pmin", "Pmax")
        for aggregate in ("min", "max", "avg")
    ]
    return "".join(queries) + 'Rmax=? [C <= 200]'


def _parse_tempest_output(lines):
    """Return ``(results, num_states)`` parsed from tempest's output lines.

    Collects numbers from "Result" lines until at least six are gathered: an
    interval ``[lo, hi]`` contributes both bounds, otherwise the text after the
    last ``:`` is parsed as a single float.
    """
    results = []
    num_states = 0
    for line in lines:
        if "States:" in line:
            num_states = int(line.split(" ")[-1])
        if "Result" in line and len(results) < 6:
            range_value = _RANGE_RESULT.search(line)
            if range_value:
                results.append(float(range_value.group(2)))
                results.append(float(range_value.group(3)))
            else:
                results.append(float(_SCALAR_RESULT.search(line).group(2)))
    return results, num_states


def computeStateRanking(mdp_file, iteration):
    """Model-check *mdp_file* with tempest and archive the produced action ranking.

    tempest writes ``action_ranking``, which is renamed to
    ``action_ranking_<iteration:03>``.

    Args:
        mdp_file: PRISM model to check.
        iteration: Number used to name the archived ranking file.

    Returns:
        ``(TestResult, num_states)``. The six init_check_* fields hold the
        model-checking results (NaN for any value that could not be obtained);
        all counters are 0. ``num_states`` is 0 if tempest failed.
    """
    logger.info("Computing state ranking")
    tic()
    prop = _ranking_property()
    results = list()
    # Bound before the try so a tempest failure cannot raise UnboundLocalError.
    num_states = 0
    # argv list instead of a shell string: no quoting of `prop` required.
    command = [tempest_binary, "--prism", mdp_file, "--buildchoicelab", "--buildstateval",
               "--build-all-labels", "--prop", prop]
    try:
        output = subprocess.check_output(command).decode("utf-8").split('\n')
        results, num_states = _parse_tempest_output(output)
        exec(f"mv action_ranking action_ranking_{iteration:03}")
    except subprocess.CalledProcessError as e:
        # todo die gracefully if ranking is uniform
        logger.error(f"tempest exited with status {e.returncode}")
        print(e.output)
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
    # TestResult takes exactly six init_check_* values. Pad with NaN when
    # tempest failed, and truncate because an interval result adds two values.
    init_checks = (results + [float("nan")] * 6)[:6]
    return TestResult(*init_checks, 0, 0, 0, 0, 0, 0, 0), num_states
|
|
|
|
# Compiled once; applied to every line of the ranking file. The value pattern
# accepts signed integers, decimals and scientific notation ("Value:1e-05").
# The old pattern `Value:([+-]?(\d*\.\d+)|\d+)` read "1e-05" as 1 and rejected
# negative integers.
_RANKING_VALUE = re.compile(r"Value:([+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?)")
_STATE_ASSIGNMENT = re.compile(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?")
_CHOICE_VALUE = re.compile(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)")


def fillStateRanking(file_name, match=""):
    """Parse a tempest action-ranking file into ``{State: StateValue}``.

    Only lines containing ``move=0`` whose ranking value exceeds 0.1 are kept.
    The parsed velocity is stored halved (integer division by 2).

    Args:
        file_name: Path of the ranking file written by tempest.
        match: Unused; kept for backward compatibility.

    Returns:
        Dict mapping State to StateValue.

    Exits the process with status -1 if the file cannot be opened. A line that
    cannot be parsed is logged and the exception re-raised. Previously a bare
    ``except`` returned None, which only failed later in the caller.
    """
    logger.info(f"Parsing state ranking, {file_name}")
    tic()
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
    except OSError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(-1)

    state_ranking = dict()
    try:
        for line in file_content:
            if "move=0" not in line:
                continue
            ranking_value = float(_RANKING_VALUE.search(line).group(1))
            if ranking_value <= 0.1:
                continue
            stateMapping = convert(_STATE_ASSIGNMENT.findall(line))
            choices = {key: float(value) for key, value in convert(_CHOICE_VALUE.findall(line)).items()}
            state = State(
                int(stateMapping["x"]),
                int(stateMapping["y"]),
                int(stateMapping["ski_position"]),
                int(stateMapping["velocity"]) // 2,
            )
            state_ranking[state] = StateValue(ranking_value, choices)
    except Exception:
        logger.exception(f"Could not parse state ranking {file_name}")
        toc()
        raise
    logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
    return state_ranking
|
|
|
|
def createDisjunction(formulas):
    """Join *formulas* into a flat disjunction ``f1 | f2 | ...``."""
    separator = " | "
    return separator.join(formulas)
|
|
|
|
def _x_ranges(xs):
    """Compress x coordinates into maximal runs of consecutive integers.

    Example: [5, 1, 2, 3, 10] -> [(1, 3), (5, 5), (10, 10)]. *xs* must be non-empty.
    """
    ordered = sorted(xs)
    ranges = []
    start = previous = ordered[0]
    for x in ordered[1:]:
        if x - previous <= 1:  # consecutive (or duplicate): extend the current run
            previous = x
            continue
        ranges.append((start, previous))
        start = previous = x
    ranges.append((start, previous))
    return ranges


def statesFormulaTrimmed(states, name):
    """Render PRISM ``formula <name>_<p> = ...;`` definitions for every ski position p.

    *states* is an iterable of ``(x, y, ski_position, velocity)`` tuples. For a
    ski position with states, the formula is ``ski_position=p & (...)``, where
    ``(...)`` is a balanced disjunction, over all velocities and y values, of
    ``y=<y> & (<x interval> | ...)``. The velocity is deliberately not
    constrained (that term was commented out). Ski positions without any state
    get ``formula <name>_p = false;``.

    The x intervals are exact maximal runs of consecutive x values. The earlier
    version closed the last interval at the largest x, ignoring gaps (e.g.
    x in {1, 2, 3, 10} became 1<=x<=10), and so over-approximated the set.

    Returns:
        The formulas joined by newlines, with a trailing newline.
    """
    by_ski_position = defaultdict(list)
    for item in states:
        by_ski_position[item[2]].append(item)

    formulas = list()
    for ski_position, ski_group in by_ski_position.items():
        by_velocity = defaultdict(list)
        for item in ski_group:
            by_velocity[item[3]].append(item)
        velocity_formulas = list()
        for velocity_group in by_velocity.values():
            xs_by_y = defaultdict(list)
            for item in velocity_group:
                xs_by_y[item[1]].append(item[0])
            y_formulas = list()
            for y, xs in xs_by_y.items():
                x_ranges = [f" ({low}<=x&x<={high})" for low, high in _x_ranges(xs)]
                y_formulas.append(f" (y={y} & {createBalancedDisjunction(x_ranges)})")
            # Velocity intentionally unconstrained (was: velocity={velocity} & ...).
            velocity_formulas.append(f"({createBalancedDisjunction(y_formulas)})")
        formulas.append(
            f"formula {name}_{ski_position} = ( ski_position={ski_position} & "
            + createBalancedDisjunction(velocity_formulas)
            + ");"
        )
    for i in range(1, num_ski_positions + 1):
        if i not in by_ski_position:
            formulas.append(f"formula {name}_{i} = false;")
    return "\n".join(formulas) + "\n"
|
|
|
|
# Pairing idiom from
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
def pairwise(iterable):
    """Yield non-overlapping pairs s -> (s0, s1), (s2, s3), ...; an odd last item is dropped.

    Unlike itertools.pairwise, consecutive pairs do not overlap.
    """
    shared = iter(iterable)
    # zip draws from the same iterator twice per tuple, so neighbours pair up.
    return zip(shared, shared)


def createBalancedDisjunction(formulas):
    """OR *formulas* together as a balanced binary tree.

    Neighbours are combined pairwise, level by level, and an odd leftover moves
    up unchanged, so the nesting depth grows logarithmically in len(formulas).
    Returns "false" for an empty list and the sole formula for a single one.
    """
    if len(formulas) == 0:
        return "false"
    level = formulas
    while len(level) > 1:
        carry = [level[-1]] if len(level) % 2 == 1 else []
        level = [f"({left} | {right})" for left, right in pairwise(level)] + carry
    # Exactly one formula is left at this point.
    return " ".join(level)
|
|
|
|
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    """Write ``<newFile>_<iteration:03>.prism``: the base model plus Safe/Unsafe formulas.

    ``<newFile>_no_formulas.prism`` is copied, then the Safe_i / Unsafe_i
    formulas (see statesFormulaTrimmed) and a matching label for every ski
    position are appended.

    Args:
        newFile: Path prefix of the model files (e.g. "velocity_safety").
        iteration: Iteration number used in the output file name.
        safeStates, unsafeStates: Iterables of (x, y, ski_position, velocity) tuples.

    Raises:
        OSError: If the base model cannot be copied. A failing shell ``cp`` went
            unnoticed, and the formulas were appended to a file without a model.
    """
    import shutil

    logger.info("Creating next prism file")
    tic()
    initFile = f"{newFile}_no_formulas.prism"
    targetFile = f"{newFile}_{iteration:03}.prism"
    shutil.copyfile(initFile, targetFile)
    with open(targetFile, "a") as prism:
        prism.write(statesFormulaTrimmed(safeStates, "Safe"))
        prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe"))
        for i in range(1, num_ski_positions + 1):
            prism.write(f"label \"Safe_{i}\" = Safe_{i};\n")
            prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n")
    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
|
|
|
|
|
|
# --- Module-level setup: runs on import --------------------------------------
# Emulator with the Skiing ROM, shared by all tests.
ale = ALEInterface()
# To watch/hear the emulator when SDL is available, enable:
# if SDL_SUPPORT:
#     ale.setBool("sound", True)
#     ale.setBool("display_screen", True)
ale.loadROM(rom_file)

# RAM snapshots; run_single_test restores ramDICT[y] before each test.
# NOTE(review): pickle.load can execute arbitrary code, so only load a
# trusted, locally generated all_positions_v2.pickle.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)

# Not read by any function in this file.
y_ram_setting = 60
x = 70

# Policy under test.
nn_wrapper = SampleFactoryNNQueryWrapper()

# Every run writes its images below images/testing_<unix timestamp>.
experiment_id = int(time.time())
init_mdp = "velocity_safety"
imagesDir = f"images/testing_{experiment_id}"
exec(f"mkdir -p {imagesDir}", verbose=False)
|
|
|
|
def _read_image(path):
    """Load *path* with OpenCV, raising instead of returning None when it fails."""
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image {path}")
    return image


def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0, markerSize=1, drawCircle=False):
    """Draw one marker per state onto the per-(ski position, velocity) images on disk.

    Each marker is drawn onto
    ``<imagesDir>/<target_prefix>_<pos:02>_<vel:02>_individual.png`` and onto
    the merged per-position image ``<imagesDir>/<target_prefix>_<pos:02>_individual.png``.
    Only images that receive at least one marker are read and rewritten.
    Decoding and re-encoding all of them on every call was the dominant cost,
    since this is called once per cluster.

    Args:
        states: Iterable of ``(State, StateValue)`` pairs.
        color: BGR color tuple for OpenCV.
        target_prefix: File-name prefix of the images to draw on.
        alpha_factor: Kept for compatibility; the OpenCV primitives used here
            draw opaquely, so it has no effect.
        markerSize: Half side length of square markers (circles use radius 1).
        drawCircle: Draw filled circles instead of filled squares.
    """
    # One marker list per valid (ski position, velocity); a state outside the
    # valid range raises KeyError here.
    markerList = {
        (ski_position, velocity): list()
        for velocity in range(0, num_velocities)
        for ski_position in range(1, num_ski_positions + 1)
    }
    for state in states:
        s = state[0]
        markerList[(s.ski_position, s.velocity)].append(
            ((s.x - markerSize, s.y - markerSize), (s.x + markerSize, s.y + markerSize))
        )

    mergedImages = dict()
    for (pos, vel), corners in markerList.items():
        if not corners:
            continue
        individual_path = f"{imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png"
        image = _read_image(individual_path)
        if pos not in mergedImages:
            mergedImages[pos] = _read_image(f"{imagesDir}/{target_prefix}_{pos:02}_individual.png")
        for top_left, bottom_right in corners:
            if drawCircle:
                # Centred on the top-left corner, i.e. on the state itself when markerSize == 0.
                image = cv2.circle(image, top_left, 1, color, thickness=-1)
                mergedImages[pos] = cv2.circle(mergedImages[pos], top_left, 1, color, thickness=-1)
            else:
                image = cv2.rectangle(image, top_left, bottom_right, color, cv2.FILLED)
                mergedImages[pos] = cv2.rectangle(mergedImages[pos], top_left, bottom_right, color, cv2.FILLED)
        cv2.imwrite(individual_path, image)
    for pos, merged in mergedImages.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{pos:02}_individual.png", merged)
|
|
|
|
|
|
def concatImages(prefix, iteration):
    """Label the per-(position, velocity) images and tile them into overview images.

    Each ``<prefix>_<pos:02>_<vel:02>_individual.png`` is annotated in place
    with ``p<pos>v<vel>`` (ImageMagick ``convert``). ImageMagick ``montage``
    then writes ``<prefix>_<iteration:03>.png``, with images ordered by
    velocity and then ski position, 8 per row. It also writes
    ``<prefix>_<iteration:03>_merged.png`` from the per-position merged images.
    """
    logger.info(f"Concatenating images")
    positions = range(1, num_ski_positions + 1)
    velocities = range(0, num_velocities)

    def individual(pos, vel):
        return f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"

    tiles = [individual(pos, vel) for vel in velocities for pos in positions]
    merged_tiles = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in positions]
    for vel in velocities:
        for pos in positions:
            path = individual(pos, vel)
            annotate = f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}'"
            exec(f"convert {path} {annotate} {path}", verbose=False)
    # NOTE(review): the 8x9 grid is larger than 8 positions x 5 velocities; confirm intent.
    exec(f"montage {' '.join(tiles)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}.png", verbose=False)
    exec(f"montage {' '.join(merged_tiles)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}_merged.png", verbose=False)
    #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
    logger.info(f"Concatenating images - DONE")
|
|
|
|
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """Draw a set of states (e.g. a single cluster) onto a tiled image. TODO.

    Currently disabled: the function does nothing and returns None. The earlier
    ImageMagick-based draft is kept below as comments. It referenced a
    ``markerSize`` that is not a parameter here.
    """
    # markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    # logger.info(f"Drawing {len(states)} states onto {target}")
    # tic()
    # for state in states:
    #     s = state[0]
    #     marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
    #     markerList[s.ski_position].append(marker)
    # for pos, marker in markerList.items():
    #     command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
    #     exec(command, verbose=False)
    # exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    # logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    return None
|
|
|
|
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
    """Paint each cluster in a random color, then tile the resulting images.

    Args:
        clusterDict: Mapping cluster id -> list of (State, StateValue) pairs.
        target: Image prefix handed to drawOntoSkiPosImage and concatImages.
        iteration: Iteration number used in the tiled image's file name.
        alpha_factor: Forwarded to drawOntoSkiPosImage.
    """
    logger.info(f"Drawing {len(clusterDict)} clusters")
    tic()
    for members in clusterDict.values():
        # Three independent draws from 0..255, converted to Python ints for OpenCV.
        color = tuple(int(np.random.choice(range(256))) for _ in range(3))
        drawOntoSkiPosImage(members, color, target, alpha_factor=alpha_factor)
    concatImages(target, iteration)
    logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
|
|
|
|
def drawResult(clusterDict, target, iteration, drawnCluster=None):
    """Draw every not-yet-drawn cluster in the color of its verdict.

    Args:
        clusterDict: Mapping cluster id -> (list of (State, StateValue), Verdict).
        target: Image prefix handed to drawOntoSkiPosImage.
        iteration: Unused; kept for signature compatibility.
        drawnCluster: Ids to skip because they were drawn before. Defaults to
            none; a fresh set is created per call instead of sharing one
            mutable default object across calls.
    """
    if drawnCluster is None:
        drawnCluster = set()
    logger.info(f"Drawing {len(clusterDict)} results")
    tic()
    # OpenCV expects BGR: GOOD -> green, BAD -> red, anything else -> grey.
    verdict_colors = {Verdict.GOOD: (0, 200, 0), Verdict.BAD: (0, 0, 200)}
    for cluster_id, (clusterStates, result) in clusterDict.items():
        if cluster_id in drawnCluster:
            continue
        color = verdict_colors.get(result, (100, 100, 100))
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
    logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
|
|
|
|
def _init_logger():
|
|
logger = logging.getLogger('main')
|
|
logger.setLevel(logging.INFO)
|
|
handler = logging.StreamHandler(sys.stdout)
|
|
formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
|
|
handler.setFormatter(formatter)
|
|
logger.addHandler(handler)
|
|
|
|
def clusterImportantStates(ranking, iteration):
    """Group ranked states into clusters with k-means.

    Each state becomes the feature vector
    ``[x, y, 20*ski_position, 20*velocity, ranking]``. The factor 20 presumably
    spreads the small index ranges to be comparable with pixel coordinates
    (TODO confirm). About one cluster is requested per 15 states, and at least
    one.

    Args:
        ranking: Non-empty sequence of (State, StateValue) pairs.
        iteration: Unused (only by the commented-out drawClusters call).

    Returns:
        Dict mapping cluster id -> list of (State, StateValue) pairs. States
        labelled -1 (noise; only density-based clusterers such as DBSCAN emit
        it) each get a singleton cluster with an id after the regular ones.
    """
    logger.info(f"Starting to cluster {len(ranking)} states into clusters")
    tic()
    states = [[s[0].x, s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
    #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
    # KMeans rejects n_clusters=0, which len(states) // 15 gives for < 15 states.
    n_requested = max(1, len(states) // 15)
    kmeans = KMeans(n_requested, random_state=0, n_init="auto").fit(states)
    #dbscan = DBSCAN(eps=5).fit(states)
    #labels = dbscan.labels_
    labels = kmeans.labels_
    # Key the dict by the labels actually present, so non-contiguous labels
    # cannot cause a KeyError.
    cluster_ids = sorted({int(label) for label in labels if label != -1})
    n_clusters = len(cluster_ids)
    logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
    clusterDict = {cluster_id: list() for cluster_id in cluster_ids}
    next_stray_id = cluster_ids[-1] + 1 if cluster_ids else 0
    strayStates = list()
    for state, label in zip(ranking, labels):
        if label == -1:
            clusterDict[next_stray_id + len(strayStates)] = [state]
            strayStates.append(state)
            continue
        clusterDict[int(label)].append(state)
    if strayStates:
        logger.warning(f"{len(strayStates)} stray states with label -1")
    #drawClusters(clusterDict, f"clusters", iteration)
    return clusterDict
|
|
|
|
|
|
def _copy_base_images(source):
    """Reset every 'clusters_*' and 'result_*' image in imagesDir to a copy of *source*."""
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0, num_velocities):
            exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
            exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)


def _ratio(numerator, denominator):
    """Return numerator / denominator, or inf (nan for 0/0) when the denominator is zero."""
    if denominator:
        return numerator / denominator
    return float("inf") if numerator else float("nan")


def run_experiment(factor_tests_per_cluster):
    """Cluster the ranked states, test sampled states of each cluster and report.

    Builds and model-checks ``<init_mdp>_000.prism``, parses the state ranking,
    and clusters the states ranked above 0.1. For each cluster it runs
    ``factor_tests_per_cluster * len(cluster)`` simulations (at least one, at
    most len(cluster)) from distinct randomly chosen member states. A cluster is
    BAD if any of its tests is BAD, otherwise GOOD. After about every 5% of
    the clusters, a statistics snapshot is taken, result images are drawn and
    the CSV table is printed. Finally all completed snapshots are written to
    ``data_new_method_<factor_tests_per_cluster>``.

    Args:
        factor_tests_per_cluster: Fraction of each cluster's states to test.
    """
    logger.info("Starting")
    num_queries = 0
    _copy_base_images("images/1_full_scaled_down.png")

    goodVerdicts = 0
    badVerdicts = 0
    goodVerdictTestCases = list()
    badVerdictTestCases = list()
    safeClusters = 0
    unsafeClusters = 0
    safeStates = set()
    unsafeStates = set()
    iteration = 0
    results = list()

    updatePrismFile(init_mdp, iteration, set(), set())
    #modelCheckingResult, numStates = TestResult(0,0,0,0,0,0,0,0,0,0,0,0,0), 10
    modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_000.prism", iteration)
    results.append(modelCheckingResult)
    ranking = fillStateRanking(f"action_ranking_000")

    sorted_ranking = sorted((x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
    try:
        clusters = clusterImportantStates(sorted_ranking, iteration)
    except Exception as e:
        print(e)
        sys.exit(-1)

    clusterResult = dict()
    logger.info(f"Running tests")
    tic()
    num_cluster_tested = 0
    drawnCluster = set()
    # Report after about every 5% of the clusters. At least 1 avoids the
    # ZeroDivisionError that `% (len(clusters)//20)` raised for < 20 clusters.
    report_every = max(1, len(clusters) // 20)
    for cluster_id, cluster in clusters.items():
        # At least one test, and never more than np.random.choice can draw without replacement.
        num_tests = min(max(1, int(factor_tests_per_cluster * len(cluster))), len(cluster))
        logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {cluster_id}")
        randomStates = [cluster[i] for i in np.random.choice(len(cluster), num_tests, replace=False)]

        verdictGood = True
        for state in randomStates:
            s = state[0]
            result, num_queries_this_test_case = run_single_test(ale, nn_wrapper, s.x, s.y, s.ski_position, s.velocity, duration=50)
            num_queries += num_queries_this_test_case
            if result == Verdict.BAD:
                clusterResult[cluster_id] = (cluster, Verdict.BAD)
                verdictGood = False
                unsafeStates.update([(c[0].x, c[0].y, c[0].ski_position, c[0].velocity) for c in cluster])
                badVerdicts += 1
                badVerdictTestCases.append(state)
            elif result == Verdict.GOOD:
                goodVerdicts += 1
                goodVerdictTestCases.append(state)
        if verdictGood:
            clusterResult[cluster_id] = (cluster, Verdict.GOOD)
            safeClusters += 1
            safeStates.update([(c[0].x, c[0].y, c[0].ski_position, c[0].velocity) for c in cluster])
        else:
            unsafeClusters += 1

        snapshot = results[-1]
        snapshot.safe_states = len(safeStates)
        snapshot.unsafe_states = len(unsafeStates)
        snapshot.policy_queries = num_queries
        snapshot.safe_cluster = safeClusters
        snapshot.unsafe_cluster = unsafeClusters
        snapshot.good_verdicts = goodVerdicts
        snapshot.bad_verdicts = badVerdicts
        num_cluster_tested += 1

        if num_cluster_tested % report_every == 0:
            iteration += 1
            # _ratio: no ZeroDivisionError before the first unsafe state / BAD verdict.
            logger.info(f"Tested Cluster: {num_cluster_tested:03}\tSafe Cluster States : {len(safeStates)}({safeClusters}/{len(clusters)})\tUnsafe Cluster States:{len(unsafeStates)}({unsafeClusters}/{len(clusters)})\tGood Test Cases:{goodVerdicts}\tFailing Test Cases:{badVerdicts}\t{_ratio(len(safeStates), len(unsafeStates))} - {_ratio(goodVerdicts, badVerdicts)}")
            drawResult(clusterResult, "result", iteration, drawnCluster)
            drawOntoSkiPosImage(goodVerdictTestCases, (10, 255, 50), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
            drawOntoSkiPosImage(badVerdictTestCases, (0, 0, 0), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
            concatImages("result", iteration)
            drawnCluster.update(clusterResult.keys())
            #updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
            #modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
            results.append(deepcopy(modelCheckingResult))
            logger.info(f"Model Checking Result: {modelCheckingResult}")
            # Account for self-loop states after first iteration
            if iteration > 0:
                current, previous = results[-1], results[-2]
                total = numStates + len(safeStates) + len(unsafeStates)
                current.init_check_pes_avg = 1/total * (current.init_check_pes_avg*numStates + 1.0*previous.unsafe_states + 0.0*previous.safe_states)
                current.init_check_opt_avg = 1/total * (current.init_check_opt_avg*numStates + 0.0*previous.unsafe_states + 1.0*previous.safe_states)
            print(TestResult.csv_header())
            for row in results[:-1]:
                print(row.csv())

    with open(f"data_new_method_{factor_tests_per_cluster}", "w") as f:
        f.write(TestResult.csv_header() + "\n")
        for row in results[:-1]:
            f.write(row.csv() + "\n")
|
|
|
|
|
|
_init_logger()
# Rebind the module logger: every function above logs through 'main' from now on.
logger = logging.getLogger('main')

if __name__ == '__main__':
    # Sweep the fraction of states tested per cluster over 0.1, 0.2, ..., 1.0.
    # step / 10 is correctly rounded, so these are exactly the literals 0.1 ... 1.0.
    for factor_tests_per_cluster in [step / 10 for step in range(1, 11)]:
        run_experiment(factor_tests_per_cluster)
|