Browse Source

started working on iterative workflow

- cluster result to formula
- etc

TODO: Properly implement naming scheme for images
add_velocity_into_framework
sp 8 months ago
parent
commit
bf5b21872c
  1. 143
      rom_evaluate.py

143
rom_evaluate.py

@ -2,6 +2,8 @@ import sys
import operator
from os import listdir, system
import re
from collections import defaultdict
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
#from colors import *
@ -190,39 +192,76 @@ def fillStateRanking(file_name, match=""):
except:
toc()
def createDisjunction(formulas):
return " | ".join(formulas)
def clusterFormula(cluster):
formula = ""
states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
skiPositionGroup = defaultdict(list)
for item in states:
skiPositionGroup[item[2]].append(item)
first = True
for skiPosition, group in skiPositionGroup.items():
formula += f"ski_position={skiPosition} & ("
yPosGroup = defaultdict(list)
for item in group:
yPosGroup[item[1]].append(item)
for y, y_group in yPosGroup.items():
if first:
first = False
else:
formula += " | "
sorted_y_group = sorted(y_group, key=lambda s: s[0])
formula += f"( y={y} & ("
current_x_min = sorted_y_group[0][0]
current_x = sorted_y_group[0][0]
x_ranges = list()
for state in sorted_y_group[1:-1]:
if state[0] - current_x == 1:
current_x = state[0]
else:
x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
current_x_min = state[0]
current_x = state[0]
x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})")
formula += " | ".join(x_ranges)
formula += ") )"
formula += ")"
return formula
def createUnsafeFormula(clusters):
formulas = ""
disjunction = "formula Unsafe = false"
for i, cluster in enumerate(clusters):
formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n"
clusterFormula(cluster)
disjunction += f" | Unsafe_{i}"
disjunction += ";\n"
return formulas + "\n" + disjunction
def createSafeFormula(clusters):
formulas = ""
disjunction = "formula Safe = false"
for i, cluster in enumerate(clusters):
formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n"
disjunction += f" | Safe_{i}"
disjunction += ";\n"
return formulas + "\n" + disjunction
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
logger.info("Creating next prism file")
tic()
initFile = f"{newFile}_no_formulas.prism"
newFile = f"{newFile}_{iteration:03}.prism"
exec(f"cp {initFile} {newFile}", verbose=False)
with open(newFile, "a") as prism:
prism.write(createSafeFormula(safeStates))
prism.write(createUnsafeFormula(unsafeStates))
logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
fixed_left_states = list()
fixed_right_states = list()
fixed_noop_states = list()
def populate_fixed_actions(state, action):
if action == Action.LEFT:
fixed_left_states.append(state)
if action == Action.RIGHT:
fixed_right_states.append(state)
if action == Action.NOOP:
fixed_noop_states.append(state)
def update_prism_file(old_prism_file, new_prism_file):
fixed_left_formula = "formula Fixed_Left = false "
fixed_right_formula = "formula Fixed_Right = false "
fixed_noop_formula = "formula Fixed_Noop = false "
for state in fixed_left_states:
fixed_left_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
for state in fixed_right_states:
fixed_right_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
for state in fixed_noop_states:
fixed_noop_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
fixed_left_formula += ";\n"
fixed_right_formula += ";\n"
fixed_noop_formula += ";\n"
with open(f'{old_prism_file}', 'r') as file :
filedata = file.read()
if len(fixed_left_states) > 0: filedata = re.sub(r"^formula Fixed_Left =.*$", fixed_left_formula, filedata, flags=re.MULTILINE)
if len(fixed_right_states) > 0: filedata = re.sub(r"^formula Fixed_Right =.*$", fixed_right_formula, filedata, flags=re.MULTILINE)
if len(fixed_noop_states) > 0: filedata = re.sub(r"^formula Fixed_Noop =.*$", fixed_noop_formula, filedata, flags=re.MULTILINE)
with open(f'{new_prism_file}', 'w') as file:
file.write(filedata)
ale = ALEInterface()
@ -245,9 +284,9 @@ nn_wrapper = SampleFactoryNNQueryWrapper()
iteration = 0
experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}")
exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png")
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
#exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png")
#exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
markerSize = 1
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
@ -265,9 +304,9 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.
exec(command, verbose=False)
def concatImages(prefix):
exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}.png", verbose=False)
exec(f"sxiv {imagesDir}/{prefix}.png&")
def concatImages(prefix, iteration):
exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
#exec(f"sxiv {imagesDir}/{prefix}.png&", verbose=False)
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
"""
@ -286,16 +325,16 @@ def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled
exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
def drawClusters(clusterDict, target, alpha_factor=1.0):
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
for ski_position in range(1, num_ski_positions + 1):
source = "images/1_full_scaled_down.png"
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png")
for _, clusterStates in clusterDict.items():
color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
concatImages(target)
concatImages(target, iteration)
def drawResult(clusterDict, target):
def drawResult(clusterDict, target, iteration):
for ski_position in range(1, num_ski_positions + 1):
source = "images/1_full_scaled_down.png"
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png")
@ -306,7 +345,7 @@ def drawResult(clusterDict, target):
elif result == Verdict.BAD:
color = "200,0,0"
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
concatImages(target)
concatImages(target, iteration)
def _init_logger():
logger = logging.getLogger('main')
@ -316,7 +355,7 @@ def _init_logger():
handler.setFormatter(formatter)
logger.addHandler(handler)
def clusterImportantStates(ranking, n_clusters=40):
def clusterImportantStates(ranking, iteration, n_clusters=40):
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
tic()
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking]
@ -325,20 +364,25 @@ def clusterImportantStates(ranking, n_clusters=40):
clusterDict = {i : list() for i in range(0,n_clusters)}
for i, state in enumerate(ranking):
clusterDict[kmeans.labels_[i]].append(state)
drawClusters(clusterDict, f"clusters")
drawClusters(clusterDict, f"clusters", iteration)
return clusterDict
if __name__ == '__main__':
_init_logger()
logger = logging.getLogger('main')
logger.info("Starting")
n_clusters = 40
n_clusters = 2
testAll = False
safeStates = list()
unsafeStates = list()
iteration = 0
while True:
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
#computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
ranking = fillStateRanking("action_ranking")
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
clusters = clusterImportantStates(sorted_ranking, n_clusters)
clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters)
if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
clusterResult = dict()
@ -361,10 +405,13 @@ if __name__ == '__main__':
else:
clusterResult[id] = (cluster, Verdict.BAD)
verdictGood = False
unsafeStates.append(cluster)
break
if verdictGood:
clusterResult[id] = (cluster, Verdict.GOOD)
safeStates.append(cluster)
if testAll: drawClusters(failingPerCluster, f"failing")
drawResult(clusterResult, "result")
drawResult(clusterResult, "result", iteration)
iteration += 1
update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")
Loading…
Cancel
Save