Browse Source

started working on iterative workflow

- cluster result to formula
- etc

TODO: Properly implement naming scheme for images
add_velocity_into_framework
sp 6 months ago
parent
commit
bf5b21872c
  1. 143
      rom_evaluate.py

143
rom_evaluate.py

@ -2,6 +2,8 @@ import sys
import operator import operator
from os import listdir, system from os import listdir, system
import re import re
from collections import defaultdict
from random import randrange from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action from ale_py import ALEInterface, SDL_SUPPORT, Action
#from colors import * #from colors import *
@ -190,39 +192,76 @@ def fillStateRanking(file_name, match=""):
except: except:
toc() toc()
def createDisjunction(formulas):
return " | ".join(formulas)
def clusterFormula(cluster):
formula = ""
states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
skiPositionGroup = defaultdict(list)
for item in states:
skiPositionGroup[item[2]].append(item)
first = True
for skiPosition, group in skiPositionGroup.items():
formula += f"ski_position={skiPosition} & ("
yPosGroup = defaultdict(list)
for item in group:
yPosGroup[item[1]].append(item)
for y, y_group in yPosGroup.items():
if first:
first = False
else:
formula += " | "
sorted_y_group = sorted(y_group, key=lambda s: s[0])
formula += f"( y={y} & ("
current_x_min = sorted_y_group[0][0]
current_x = sorted_y_group[0][0]
x_ranges = list()
for state in sorted_y_group[1:-1]:
if state[0] - current_x == 1:
current_x = state[0]
else:
x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
current_x_min = state[0]
current_x = state[0]
x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})")
formula += " | ".join(x_ranges)
formula += ") )"
formula += ")"
return formula
def createUnsafeFormula(clusters):
formulas = ""
disjunction = "formula Unsafe = false"
for i, cluster in enumerate(clusters):
formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n"
clusterFormula(cluster)
disjunction += f" | Unsafe_{i}"
disjunction += ";\n"
return formulas + "\n" + disjunction
def createSafeFormula(clusters):
formulas = ""
disjunction = "formula Safe = false"
for i, cluster in enumerate(clusters):
formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n"
disjunction += f" | Safe_{i}"
disjunction += ";\n"
return formulas + "\n" + disjunction
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
logger.info("Creating next prism file")
tic()
initFile = f"{newFile}_no_formulas.prism"
newFile = f"{newFile}_{iteration:03}.prism"
exec(f"cp {initFile} {newFile}", verbose=False)
with open(newFile, "a") as prism:
prism.write(createSafeFormula(safeStates))
prism.write(createUnsafeFormula(unsafeStates))
logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
fixed_left_states = list()
fixed_right_states = list()
fixed_noop_states = list()
def populate_fixed_actions(state, action):
if action == Action.LEFT:
fixed_left_states.append(state)
if action == Action.RIGHT:
fixed_right_states.append(state)
if action == Action.NOOP:
fixed_noop_states.append(state)
def update_prism_file(old_prism_file, new_prism_file):
fixed_left_formula = "formula Fixed_Left = false "
fixed_right_formula = "formula Fixed_Right = false "
fixed_noop_formula = "formula Fixed_Noop = false "
for state in fixed_left_states:
fixed_left_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
for state in fixed_right_states:
fixed_right_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
for state in fixed_noop_states:
fixed_noop_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
fixed_left_formula += ";\n"
fixed_right_formula += ";\n"
fixed_noop_formula += ";\n"
with open(f'{old_prism_file}', 'r') as file :
filedata = file.read()
if len(fixed_left_states) > 0: filedata = re.sub(r"^formula Fixed_Left =.*$", fixed_left_formula, filedata, flags=re.MULTILINE)
if len(fixed_right_states) > 0: filedata = re.sub(r"^formula Fixed_Right =.*$", fixed_right_formula, filedata, flags=re.MULTILINE)
if len(fixed_noop_states) > 0: filedata = re.sub(r"^formula Fixed_Noop =.*$", fixed_noop_formula, filedata, flags=re.MULTILINE)
with open(f'{new_prism_file}', 'w') as file:
file.write(filedata)
ale = ALEInterface() ale = ALEInterface()
@ -245,9 +284,9 @@ nn_wrapper = SampleFactoryNNQueryWrapper()
iteration = 0 iteration = 0
experiment_id = int(time.time()) experiment_id = int(time.time())
init_mdp = "velocity_safety" init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}")
exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png")
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
#exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png")
#exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
markerSize = 1 markerSize = 1
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} #markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
@ -265,9 +304,9 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.
exec(command, verbose=False) exec(command, verbose=False)
def concatImages(prefix):
exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}.png", verbose=False)
exec(f"sxiv {imagesDir}/{prefix}.png&")
def concatImages(prefix, iteration):
exec(f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
#exec(f"sxiv {imagesDir}/{prefix}.png&", verbose=False)
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
""" """
@ -286,16 +325,16 @@ def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled
exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False) exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
def drawClusters(clusterDict, target, alpha_factor=1.0):
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
for ski_position in range(1, num_ski_positions + 1): for ski_position in range(1, num_ski_positions + 1):
source = "images/1_full_scaled_down.png" source = "images/1_full_scaled_down.png"
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png")
for _, clusterStates in clusterDict.items(): for _, clusterStates in clusterDict.items():
color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}" color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor) drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
concatImages(target)
concatImages(target, iteration)
def drawResult(clusterDict, target):
def drawResult(clusterDict, target, iteration):
for ski_position in range(1, num_ski_positions + 1): for ski_position in range(1, num_ski_positions + 1):
source = "images/1_full_scaled_down.png" source = "images/1_full_scaled_down.png"
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png")
@ -306,7 +345,7 @@ def drawResult(clusterDict, target):
elif result == Verdict.BAD: elif result == Verdict.BAD:
color = "200,0,0" color = "200,0,0"
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7) drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
concatImages(target)
concatImages(target, iteration)
def _init_logger(): def _init_logger():
logger = logging.getLogger('main') logger = logging.getLogger('main')
@ -316,7 +355,7 @@ def _init_logger():
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
def clusterImportantStates(ranking, n_clusters=40):
def clusterImportantStates(ranking, iteration, n_clusters=40):
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster") logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
tic() tic()
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking] states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking]
@ -325,20 +364,25 @@ def clusterImportantStates(ranking, n_clusters=40):
clusterDict = {i : list() for i in range(0,n_clusters)} clusterDict = {i : list() for i in range(0,n_clusters)}
for i, state in enumerate(ranking): for i, state in enumerate(ranking):
clusterDict[kmeans.labels_[i]].append(state) clusterDict[kmeans.labels_[i]].append(state)
drawClusters(clusterDict, f"clusters")
drawClusters(clusterDict, f"clusters", iteration)
return clusterDict return clusterDict
if __name__ == '__main__': if __name__ == '__main__':
_init_logger() _init_logger()
logger = logging.getLogger('main') logger = logging.getLogger('main')
logger.info("Starting") logger.info("Starting")
n_clusters = 40
n_clusters = 2
testAll = False testAll = False
safeStates = list()
unsafeStates = list()
iteration = 0
while True: while True:
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
#computeStateRanking(f"{init_mdp}_{iteration:03}.prism") #computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
ranking = fillStateRanking("action_ranking") ranking = fillStateRanking("action_ranking")
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
clusters = clusterImportantStates(sorted_ranking, n_clusters)
clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters)
if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)} if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
clusterResult = dict() clusterResult = dict()
@ -361,10 +405,13 @@ if __name__ == '__main__':
else: else:
clusterResult[id] = (cluster, Verdict.BAD) clusterResult[id] = (cluster, Verdict.BAD)
verdictGood = False verdictGood = False
unsafeStates.append(cluster)
break break
if verdictGood: if verdictGood:
clusterResult[id] = (cluster, Verdict.GOOD) clusterResult[id] = (cluster, Verdict.GOOD)
safeStates.append(cluster)
if testAll: drawClusters(failingPerCluster, f"failing") if testAll: drawClusters(failingPerCluster, f"failing")
drawResult(clusterResult, "result")
drawResult(clusterResult, "result", iteration)
iteration += 1
update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")
Loading…
Cancel
Save