|
@ -2,6 +2,8 @@ import sys |
|
|
import operator |
|
|
import operator |
|
|
from os import listdir, system |
|
|
from os import listdir, system |
|
|
import re |
|
|
import re |
|
|
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
|
from random import randrange |
|
|
from random import randrange |
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action |
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action |
|
|
#from colors import * |
|
|
#from colors import * |
|
@ -190,39 +192,76 @@ def fillStateRanking(file_name, match=""): |
|
|
except: |
|
|
except: |
|
|
toc() |
|
|
toc() |
|
|
|
|
|
|
|
|
|
|
|
def createDisjunction(formulas):
    """Combine formula strings into one expression joined by `` | ``.

    Args:
        formulas: iterable of strings, each one operand of the disjunction.

    Returns:
        The operands separated by ``" | "``; an empty iterable gives ``""``.
    """
    separator = " | "
    return separator.join(formulas)
|
|
|
|
|
|
|
|
|
|
|
def _contiguousRanges(values):
    """Return inclusive ``(low, high)`` pairs covering runs of consecutive values."""
    ordered = sorted(set(values))
    ranges = []
    if not ordered:
        return ranges
    start = previous = ordered[0]
    for value in ordered[1:]:
        if value - previous == 1:
            previous = value
        else:
            # Gap found: close the current run and start a new one.
            ranges.append((start, previous))
            start = previous = value
    ranges.append((start, previous))
    return ranges


def clusterFormula(cluster):
    """Build a boolean expression over ``x``, ``y`` and ``ski_position`` for a cluster.

    The states are grouped by ski position and then by ``y``; within each row
    the x coordinates are compressed into inclusive ranges of consecutive
    values. The expression has the shape::

        ski_position=S & (( y=Y & ( (A<= x & x<=B) | ...) ) | ...) | ski_position=...

    Compared to the previous implementation this
      * never bridges a gap before the last x value of a row (x = 1, 2, 5 is
        emitted as two ranges, not as ``1<=x<=5``),
      * separates the per-ski-position terms with `` | `` and no longer emits
        a leading `` | `` inside the second and later groups,
      * returns ``"false"`` for an empty cluster instead of an empty string.

    The ski-position terms are joined without extra parentheses, which relies
    on ``&`` binding tighter than ``|`` in the target language.

    Args:
        cluster: iterable of indexable entries whose element ``[0]`` exposes
            ``x``, ``y`` and ``ski_position`` attributes.

    Returns:
        The expression as a string.
    """
    # ski_position -> y -> list of x coordinates (insertion order is kept).
    grouped = defaultdict(lambda: defaultdict(list))
    for entry in cluster:
        state = entry[0]
        grouped[state.ski_position][state.y].append(state.x)

    if not grouped:
        return "false"

    ski_terms = []
    for ski_position, rows in grouped.items():
        row_terms = []
        for y, xs in rows.items():
            x_ranges = [f" ({low}<= x & x<={high})" for low, high in _contiguousRanges(xs)]
            row_terms.append(f"( y={y} & (" + " | ".join(x_ranges) + ") )")
        ski_terms.append(f"ski_position={ski_position} & (" + " | ".join(row_terms) + ")")
    return " | ".join(ski_terms)
|
|
|
|
|
|
|
|
|
|
|
def createUnsafeFormula(clusters):
    """Render the ``Unsafe_<i>`` formulas plus the aggregate ``Unsafe`` formula.

    One ``formula Unsafe_<i> = ...;`` line is produced per cluster, followed by
    a blank line and ``formula Unsafe = false | Unsafe_0 | ...;``.

    ``clusterFormula`` is evaluated exactly once per cluster; the earlier
    version called it a second time and threw the result away.

    Args:
        clusters: sequence of clusters, each accepted by ``clusterFormula``.

    Returns:
        The formula definitions as one string ending in a newline.
    """
    definitions = []
    members = []
    for i, cluster in enumerate(clusters):
        definitions.append(f"formula Unsafe_{i} = {clusterFormula(cluster)};\n")
        members.append(f" | Unsafe_{i}")
    # "false" is the neutral element, so the formula stays valid with no clusters.
    disjunction = "formula Unsafe = false" + "".join(members) + ";\n"
    return "".join(definitions) + "\n" + disjunction
|
|
|
|
|
|
|
|
|
|
|
def createSafeFormula(clusters):
    """Render the ``Safe_<i>`` formulas plus the aggregate ``Safe`` formula.

    One ``formula Safe_<i> = ...;`` line is produced per cluster, followed by a
    blank line and ``formula Safe = false | Safe_0 | ...;``.

    Args:
        clusters: sequence of clusters, each accepted by ``clusterFormula``.

    Returns:
        The formula definitions as one string ending in a newline.
    """
    definitions = []
    members = []
    for index, states in enumerate(clusters):
        definitions.append(f"formula Safe_{index} = {clusterFormula(states)};\n")
        members.append(f" | Safe_{index}")
    aggregate = "formula Safe = false" + "".join(members) + ";\n"
    return "".join(definitions) + "\n" + aggregate
|
|
|
|
|
|
|
|
|
|
|
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    """Create the prism file for ``iteration`` and append the Safe/Unsafe formulas.

    ``<newFile>_no_formulas.prism`` is copied to ``<newFile>_<iteration:03>.prism``
    with a ``cp`` command, then the output of ``createSafeFormula`` and
    ``createUnsafeFormula`` is appended to the copy.

    Args:
        newFile: path prefix shared by the template and the generated file.
        iteration: integer used as the zero-padded three-digit suffix.
        safeStates: clusters passed to ``createSafeFormula``.
        unsafeStates: clusters passed to ``createUnsafeFormula``.
    """
    logger.info("Creating next prism file")
    tic()
    template_path = f"{newFile}_no_formulas.prism"
    target_path = f"{newFile}_{iteration:03}.prism"
    # NOTE(review): `exec` takes a `verbose` keyword, so it is presumably the
    # project's shell helper rather than the builtin - confirm.
    exec(f"cp {template_path} {target_path}", verbose=False)
    builders = (
        (createSafeFormula, safeStates),
        (createUnsafeFormula, unsafeStates),
    )
    with open(target_path, "a") as prism:
        # Build and write one section at a time, Safe first.
        for build, states in builders:
            prism.write(build(states))

    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
|
|
|
|
|
|
|
|
# States collected by populate_fixed_actions, one list per action
# (LEFT, RIGHT, NOOP); update_prism_file reads them.
fixed_left_states = []
fixed_right_states = []
fixed_noop_states = []
|
|
|
|
|
|
|
|
|
|
|
def populate_fixed_actions(state, action):
    """Append ``state`` to the module-level list that belongs to ``action``.

    ``Action.LEFT``, ``Action.RIGHT`` and ``Action.NOOP`` map to
    ``fixed_left_states``, ``fixed_right_states`` and ``fixed_noop_states``.
    Any other action leaves all three lists untouched.

    Args:
        state: the state to record.
        action: the action compared against the three ``Action`` members.
    """
    buckets = (
        (Action.LEFT, fixed_left_states),
        (Action.RIGHT, fixed_right_states),
        (Action.NOOP, fixed_noop_states),
    )
    # Every candidate is checked, mirroring three independent `if` statements.
    for candidate, bucket in buckets:
        if action == candidate:
            bucket.append(state)
|
|
|
|
|
|
|
|
|
|
|
def update_prism_file(old_prism_file, new_prism_file):
    """Copy a prism file while replacing its ``Fixed_*`` formula lines.

    For each of ``Fixed_Left``, ``Fixed_Right`` and ``Fixed_Noop`` a formula is
    built from the matching module-level state list. If that list is non-empty,
    every line of the old file starting with ``formula <Name> =`` is replaced
    by the new formula. The result is written to ``new_prism_file``.

    Args:
        old_prism_file: path of the file to read.
        new_prism_file: path of the file to write (opened with mode ``'w'``).
    """

    def _render(name, states):
        # "false" followed by one " | (x=..&y=..&ski_position=..) " term per state.
        terms = "".join(
            f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
            for state in states
        )
        return f"formula {name} = false " + terms + ";\n"

    # All formulas are built before the input file is opened.
    replacements = [
        (name, states, _render(name, states))
        for name, states in (
            ("Fixed_Left", fixed_left_states),
            ("Fixed_Right", fixed_right_states),
            ("Fixed_Noop", fixed_noop_states),
        )
    ]

    with open(f'{old_prism_file}', 'r') as source:
        contents = source.read()

    for name, states, formula in replacements:
        if len(states) > 0:
            contents = re.sub(rf"^formula {name} =.*$", formula, contents, flags=re.MULTILINE)

    with open(f'{new_prism_file}', 'w') as sink:
        sink.write(contents)
|
|
|
|
|
|
|
|
|
|
|
ale = ALEInterface() |
|
|
ale = ALEInterface() |
|
|
|
|
|
|
|
@ -245,9 +284,9 @@ nn_wrapper = SampleFactoryNNQueryWrapper() |
|
|
iteration = 0 |
|
|
iteration = 0 |
|
|
experiment_id = int(time.time()) |
|
|
experiment_id = int(time.time()) |
|
|
init_mdp = "velocity_safety" |
|
|
init_mdp = "velocity_safety" |
|
|
exec(f"mkdir -p images/testing_{experiment_id}") |
|
|
|
|
|
exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png") |
|
|
|
|
|
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") |
|
|
|
|
|
|
|
|
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False) |
|
|
|
|
|
#exec(f"cp 1_full_scaled_down.png images/testing_{experiment_id}/testing_0000.png") |
|
|
|
|
|
#exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") |
|
|
|
|
|
|
|
|
markerSize = 1 |
|
|
markerSize = 1 |
|
|
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} |
|
|
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} |
|
@ -265,9 +304,9 @@ def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1. |
|
|
exec(command, verbose=False) |
|
|
exec(command, verbose=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def concatImages(prefix):
    """Tile all ``<prefix>_*png`` images into ``<prefix>.png`` and open it in sxiv.

    Both steps are shell commands run through ``exec``; the paths are relative
    to the module-level ``imagesDir``.

    Args:
        prefix: file name prefix of the images to combine.
    """
    montage_command = f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}.png"
    exec(montage_command, verbose=False)
    # The trailing "&" is part of the command string passed to exec.
    viewer_command = f"sxiv {imagesDir}/{prefix}.png&"
    exec(viewer_command)
|
|
|
|
|
|
|
|
def concatImages(prefix, iteration):
    """Tile all ``<prefix>_*png`` images into ``<prefix>_<iteration>.png``.

    The montage is a shell command run through ``exec``; the paths are relative
    to the module-level ``imagesDir``.

    Args:
        prefix: file name prefix of the images to combine.
        iteration: value appended to the output file name.
    """
    montage_command = f"montage {imagesDir}/{prefix}_*png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png"
    exec(montage_command, verbose=False)
    # Opening the result in the image viewer is disabled:
    #exec(f"sxiv {imagesDir}/{prefix}.png&", verbose=False)
|
|
|
|
|
|
|
|
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): |
|
|
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): |
|
|
""" |
|
|
""" |
|
@ -286,16 +325,16 @@ def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled |
|
|
exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False) |
|
|
exec(f"montage {imagesDir}/{target}_*png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False) |
|
|
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") |
|
|
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") |
|
|
|
|
|
|
|
|
def drawClusters(clusterDict, target, alpha_factor=1.0): |
|
|
|
|
|
|
|
|
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0): |
|
|
for ski_position in range(1, num_ski_positions + 1): |
|
|
for ski_position in range(1, num_ski_positions + 1): |
|
|
source = "images/1_full_scaled_down.png" |
|
|
source = "images/1_full_scaled_down.png" |
|
|
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") |
|
|
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") |
|
|
for _, clusterStates in clusterDict.items(): |
|
|
for _, clusterStates in clusterDict.items(): |
|
|
color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}" |
|
|
color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}" |
|
|
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor) |
|
|
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor) |
|
|
concatImages(target) |
|
|
|
|
|
|
|
|
concatImages(target, iteration) |
|
|
|
|
|
|
|
|
def drawResult(clusterDict, target): |
|
|
|
|
|
|
|
|
def drawResult(clusterDict, target, iteration): |
|
|
for ski_position in range(1, num_ski_positions + 1): |
|
|
for ski_position in range(1, num_ski_positions + 1): |
|
|
source = "images/1_full_scaled_down.png" |
|
|
source = "images/1_full_scaled_down.png" |
|
|
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") |
|
|
exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}.png") |
|
@ -306,7 +345,7 @@ def drawResult(clusterDict, target): |
|
|
elif result == Verdict.BAD: |
|
|
elif result == Verdict.BAD: |
|
|
color = "200,0,0" |
|
|
color = "200,0,0" |
|
|
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7) |
|
|
drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7) |
|
|
concatImages(target) |
|
|
|
|
|
|
|
|
concatImages(target, iteration) |
|
|
|
|
|
|
|
|
def _init_logger(): |
|
|
def _init_logger(): |
|
|
logger = logging.getLogger('main') |
|
|
logger = logging.getLogger('main') |
|
@ -316,7 +355,7 @@ def _init_logger(): |
|
|
handler.setFormatter(formatter) |
|
|
handler.setFormatter(formatter) |
|
|
logger.addHandler(handler) |
|
|
logger.addHandler(handler) |
|
|
|
|
|
|
|
|
def clusterImportantStates(ranking, n_clusters=40): |
|
|
|
|
|
|
|
|
def clusterImportantStates(ranking, iteration, n_clusters=40): |
|
|
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster") |
|
|
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster") |
|
|
tic() |
|
|
tic() |
|
|
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking] |
|
|
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking] |
|
@ -325,20 +364,25 @@ def clusterImportantStates(ranking, n_clusters=40): |
|
|
clusterDict = {i : list() for i in range(0,n_clusters)} |
|
|
clusterDict = {i : list() for i in range(0,n_clusters)} |
|
|
for i, state in enumerate(ranking): |
|
|
for i, state in enumerate(ranking): |
|
|
clusterDict[kmeans.labels_[i]].append(state) |
|
|
clusterDict[kmeans.labels_[i]].append(state) |
|
|
drawClusters(clusterDict, f"clusters") |
|
|
|
|
|
|
|
|
drawClusters(clusterDict, f"clusters", iteration) |
|
|
return clusterDict |
|
|
return clusterDict |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
if __name__ == '__main__': |
|
|
_init_logger() |
|
|
_init_logger() |
|
|
logger = logging.getLogger('main') |
|
|
logger = logging.getLogger('main') |
|
|
logger.info("Starting") |
|
|
logger.info("Starting") |
|
|
n_clusters = 40 |
|
|
|
|
|
|
|
|
n_clusters = 2 |
|
|
testAll = False |
|
|
testAll = False |
|
|
|
|
|
|
|
|
|
|
|
safeStates = list() |
|
|
|
|
|
unsafeStates = list() |
|
|
|
|
|
iteration = 0 |
|
|
while True: |
|
|
while True: |
|
|
|
|
|
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) |
|
|
#computeStateRanking(f"{init_mdp}_{iteration:03}.prism") |
|
|
#computeStateRanking(f"{init_mdp}_{iteration:03}.prism") |
|
|
ranking = fillStateRanking("action_ranking") |
|
|
ranking = fillStateRanking("action_ranking") |
|
|
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) |
|
|
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) |
|
|
clusters = clusterImportantStates(sorted_ranking, n_clusters) |
|
|
|
|
|
|
|
|
clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters) |
|
|
|
|
|
|
|
|
if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)} |
|
|
if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)} |
|
|
clusterResult = dict() |
|
|
clusterResult = dict() |
|
@ -361,10 +405,13 @@ if __name__ == '__main__': |
|
|
else: |
|
|
else: |
|
|
clusterResult[id] = (cluster, Verdict.BAD) |
|
|
clusterResult[id] = (cluster, Verdict.BAD) |
|
|
verdictGood = False |
|
|
verdictGood = False |
|
|
|
|
|
unsafeStates.append(cluster) |
|
|
break |
|
|
break |
|
|
if verdictGood: |
|
|
if verdictGood: |
|
|
clusterResult[id] = (cluster, Verdict.GOOD) |
|
|
clusterResult[id] = (cluster, Verdict.GOOD) |
|
|
|
|
|
safeStates.append(cluster) |
|
|
if testAll: drawClusters(failingPerCluster, f"failing") |
|
|
if testAll: drawClusters(failingPerCluster, f"failing") |
|
|
drawResult(clusterResult, "result") |
|
|
|
|
|
|
|
|
drawResult(clusterResult, "result", iteration) |
|
|
|
|
|
iteration += 1 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism") |
|
|
|