Browse Source

WIP rom_evaluate update

add_velocity_into_framework
sp 9 months ago
parent
commit
e04d141419
  1. 25
      rom_evaluate.py

25
rom_evaluate.py

@ -109,7 +109,8 @@ ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.L
#def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
#print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
testDir = f"{x}_{y}_{ski_position}"#_{velocity}"
#testDir = f"{x}_{y}_{ski_position}_{velocity}"
testDir = f"{x}_{y}_{ski_position}"
for i, r in enumerate(ramDICT[y]):
ale.setRAM(i,r)
ski_position_setting = ski_position_counter[ski_position]
@ -140,7 +141,7 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
saveObservations(all_obs, Verdict.GOOD, testDir)
return Verdict.GOOD
def computeStateRanking(mdp_file):
def computeStateRanking(mdp_file, iteration):
logger.info("Computing state ranking")
tic()
try:
@ -150,10 +151,11 @@ def computeStateRanking(mdp_file):
except Exception as e:
print(e)
sys.exit(-1)
exec(f"mv action_ranking action_ranking_{iteration:03}")
logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
def fillStateRanking(file_name, match=""):
logger.info("Parsing state ranking")
logger.info(f"Parsing state ranking, {file_name}")
tic()
state_ranking = dict()
try:
@ -170,7 +172,8 @@ def fillStateRanking(file_name, match=""):
choices = {key:float(value) for (key,value) in choices.items()}
#print("choices", choices)
#print("ranking_value", ranking_value)
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))#, int(stateMapping["velocity"]))
#state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"]))
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
value = StateValue(ranking_value, choices)
state_ranking[state] = value
logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
@ -299,8 +302,8 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.
markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
for state in states:
s = state[0]
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
#marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
#marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
markerList[s.ski_position].append(marker)
for pos, marker in markerList.items():
command = f"convert {imagesDir}/{target_prefix}_{pos:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_individual.png"
@ -366,7 +369,6 @@ def clusterImportantStates(ranking, iteration):
#kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
dbscan = DBSCAN(eps=15).fit(states)
labels = dbscan.labels_
print(labels)
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
clusterDict = {i : list() for i in range(0,n_clusters)}
@ -388,8 +390,8 @@ if __name__ == '__main__':
iteration = 0
while True:
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
ranking = fillStateRanking("action_ranking")
computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
ranking = fillStateRanking(f"action_ranking_{iteration:03}")
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
clusters = clusterImportantStates(sorted_ranking, iteration)
@ -397,6 +399,7 @@ if __name__ == '__main__':
clusterResult = dict()
for id, cluster in clusters.items():
num_tests = int(factor_tests_per_cluster * len(cluster))
num_tests = 1
logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
randomStates = np.random.choice(len(cluster), num_tests, replace=False)
randomStates = [cluster[i] for i in randomStates]
@ -408,6 +411,8 @@ if __name__ == '__main__':
ski_pos = state[0].ski_position
#velocity = state[0].velocity
result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
#result = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
result = Verdict.BAD # TODO REMOVE ME!!!!!!!!!!!!!!
if result == Verdict.BAD:
if testAll:
failingPerCluster[id].append(state)
@ -419,7 +424,9 @@ if __name__ == '__main__':
if verdictGood:
clusterResult[id] = (cluster, Verdict.GOOD)
safeStates.append(cluster)
logger.info(f"Iteration: {iteration:03} -\tSafe Results : {sum([len(c) for c in safeStates])} -\tUnsafe Results:{sum([len(c) for c in unsafeStates])}")
if testAll: drawClusters(failingPerCluster, f"failing", iteration)
drawResult(clusterResult, "result", iteration)
iteration += 1

Loading…
Cancel
Save