Browse Source

redid querying

add_velocity_into_framework
sp 8 months ago
parent
commit
019ea0ad1e
  1. 86
      rom_evaluate.py

86
rom_evaluate.py

@ -1,6 +1,7 @@
import sys import sys
import operator import operator
from os import listdir, system from os import listdir, system
import subprocess
import re import re
from collections import defaultdict from collections import defaultdict
@ -104,6 +105,7 @@ def saveObservations(observations, verdict, testDir):
img.save(f"{testDir}/{i:003}.png") img.save(f"{testDir}/{i:003}.png")
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) } ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
#def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50): #def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50): def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
#print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="") #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
@ -119,19 +121,18 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
all_obs = list() all_obs = list()
speed_list = list() speed_list = list()
first_action_set = False
first_action = 0
for i in range(0,duration):
resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA) resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
if len(all_obs) >= 4:
for i in range(0,4):
all_obs.append(resized_obs)
for i in range(0,duration-4):
resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
all_obs.append(resized_obs)
if i % 4 == 0:
stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])}) stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
action = nn_wrapper.query(stack_tensor) action = nn_wrapper.query(stack_tensor)
if not first_action_set:
first_action_set = True
first_action = input_to_action(str(action))
ale.act(input_to_action(str(action))) ale.act(input_to_action(str(action)))
else: else:
ale.act(Action.NOOP)
ale.act(input_to_action(str(action)))
speed_list.append(ale.getRAM()[14]) speed_list.append(ale.getRAM()[14])
if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0: if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
saveObservations(all_obs, Verdict.BAD, testDir) saveObservations(all_obs, Verdict.BAD, testDir)
@ -139,14 +140,16 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
saveObservations(all_obs, Verdict.GOOD, testDir) saveObservations(all_obs, Verdict.GOOD, testDir)
return Verdict.GOOD return Verdict.GOOD
def optimalAction(choices):
return max(choices.items(), key=operator.itemgetter(1))[0]
def computeStateRanking(mdp_file): def computeStateRanking(mdp_file):
logger.info("Computing state ranking") logger.info("Computing state ranking")
tic() tic()
try:
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop 'Rmax=? [C <= 1000]'" command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop 'Rmax=? [C <= 1000]'"
exec(command)
result = subprocess.run(command, shell=True, check=True)
print(result)
except Exception as e:
print(e)
sys.exit(-1)
logger.info(f"Computing state ranking - DONE: took {toc()} seconds") logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
def fillStateRanking(file_name, match=""): def fillStateRanking(file_name, match=""):
@ -221,26 +224,37 @@ def clusterFormula(cluster):
formula += ")" formula += ")"
return formula return formula
def createBalancedDisjunction(indices, name):
#logger.info(f"Creating balanced disjunction for {len(indices)} ({indices}) formulas")
if len(indices) == 0:
return f"formula {name} = false;\n"
else:
while len(indices) > 1:
indices_tmp = [f"({indices[i]} | {indices[i+1]})" for i in range(0,len(indices)//2)]
if len(indices) % 2 == 1:
indices_tmp.append(indices[-1])
indices = indices_tmp
disjunction = f"formula {name} = " + " ".join(indices) + ";\n"
return disjunction
def createUnsafeFormula(clusters): def createUnsafeFormula(clusters):
label = "label \"Unsafe\" = Unsafe;\n"
formulas = "" formulas = ""
disjunction = "formula Unsafe = false"
indices = list()
for i, cluster in enumerate(clusters): for i, cluster in enumerate(clusters):
formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n" formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n"
clusterFormula(cluster)
disjunction += f" | Unsafe_{i}"
disjunction += ";\n"
label = "label \"Unsafe\" = Unsafe;\n"
return formulas + "\n" + disjunction + label
indices.append(f"Unsafe_{i}")
return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe") + label
def createSafeFormula(clusters): def createSafeFormula(clusters):
label = "label \"Safe\" = Safe;\n"
formulas = "" formulas = ""
disjunction = "formula Safe = false"
indices = list()
for i, cluster in enumerate(clusters): for i, cluster in enumerate(clusters):
formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n" formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n"
disjunction += f" | Safe_{i}"
disjunction += ";\n"
label = "label \"Safe\" = Safe;\n"
return formulas + "\n" + disjunction + label
indices.append(f"Safe_{i}")
return formulas + "\n" + createBalancedDisjunction(indices, "Safe") + label
def updatePrismFile(newFile, iteration, safeStates, unsafeStates): def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
logger.info("Creating next prism file") logger.info("Creating next prism file")
@ -285,8 +299,8 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.
markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)} markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
for state in states: for state in states:
s = state[0] s = state[0]
#marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
#marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
markerList[s.ski_position].append(marker) markerList[s.ski_position].append(marker)
for pos, marker in markerList.items(): for pos, marker in markerList.items():
command = f"convert {imagesDir}/{target_prefix}_{pos:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_individual.png" command = f"convert {imagesDir}/{target_prefix}_{pos:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_individual.png"
@ -344,17 +358,21 @@ def _init_logger():
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
def clusterImportantStates(ranking, iteration, n_clusters=40):
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
def clusterImportantStates(ranking, iteration):
logger.info(f"Starting to cluster {len(ranking)} states into clusters")
tic() tic()
#states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[0].velocity * 10, s[1].ranking] for s in ranking] #states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[0].velocity * 10, s[1].ranking] for s in ranking]
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking]
kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
#dbscan = DBSCAN().fit(states)
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster - DONE: took {toc()} seconds")
states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
#kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
dbscan = DBSCAN(eps=15).fit(states)
labels = dbscan.labels_
print(labels)
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
clusterDict = {i : list() for i in range(0,n_clusters)} clusterDict = {i : list() for i in range(0,n_clusters)}
for i, state in enumerate(ranking): for i, state in enumerate(ranking):
clusterDict[kmeans.labels_[i]].append(state)
if labels[i] == -1: continue
clusterDict[labels[i]].append(state)
drawClusters(clusterDict, f"clusters", iteration) drawClusters(clusterDict, f"clusters", iteration)
return clusterDict return clusterDict
@ -373,7 +391,7 @@ if __name__ == '__main__':
computeStateRanking(f"{init_mdp}_{iteration:03}.prism") computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
ranking = fillStateRanking("action_ranking") ranking = fillStateRanking("action_ranking")
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters)
clusters = clusterImportantStates(sorted_ranking, iteration)
if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)} if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
clusterResult = dict() clusterResult = dict()
@ -401,7 +419,7 @@ if __name__ == '__main__':
if verdictGood: if verdictGood:
clusterResult[id] = (cluster, Verdict.GOOD) clusterResult[id] = (cluster, Verdict.GOOD)
safeStates.append(cluster) safeStates.append(cluster)
if testAll: drawClusters(failingPerCluster, f"failing")
if testAll: drawClusters(failingPerCluster, f"failing", iteration)
drawResult(clusterResult, "result", iteration) drawResult(clusterResult, "result", iteration)
iteration += 1 iteration += 1

Loading…
Cancel
Save