|
@ -1,9 +1,10 @@ |
|
|
import sys |
|
|
import sys |
|
|
import operator |
|
|
import operator |
|
|
from os import listdir, system |
|
|
from os import listdir, system |
|
|
|
|
|
import re |
|
|
from random import randrange |
|
|
from random import randrange |
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action |
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action |
|
|
from colors import * |
|
|
|
|
|
|
|
|
#from colors import * |
|
|
from PIL import Image |
|
|
from PIL import Image |
|
|
from matplotlib import pyplot as plt |
|
|
from matplotlib import pyplot as plt |
|
|
import cv2 |
|
|
import cv2 |
|
@ -11,13 +12,18 @@ import pickle |
|
|
import queue |
|
|
import queue |
|
|
from dataclasses import dataclass, field |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
|
|
|
|
|
|
from sklearn.cluster import KMeans |
|
|
|
|
|
|
|
|
from enum import Enum |
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
import readchar |
|
|
|
|
|
|
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
#import readchar |
|
|
|
|
|
|
|
|
from sample_factory.algo.utils.tensor_dict import TensorDict |
|
|
from sample_factory.algo.utils.tensor_dict import TensorDict |
|
|
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper |
|
|
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper |
|
@ -25,9 +31,17 @@ from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper |
|
|
import time |
|
|
import time |
|
|
|
|
|
|
|
|
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm" |
|
|
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm" |
|
|
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin" |
|
|
|
|
|
|
|
|
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin" |
|
|
|
|
|
|
|
|
|
|
|
def tic(): |
|
|
|
|
|
import time |
|
|
|
|
|
global startTime_for_tictoc |
|
|
|
|
|
startTime_for_tictoc = time.time() |
|
|
|
|
|
|
|
|
|
|
|
def toc(): |
|
|
|
|
|
import time |
|
|
|
|
|
if 'startTime_for_tictoc' in globals(): |
|
|
|
|
|
return time.time() - startTime_for_tictoc |
|
|
|
|
|
|
|
|
class Verdict(Enum): |
|
|
class Verdict(Enum): |
|
|
INCONCLUSIVE = 1 |
|
|
INCONCLUSIVE = 1 |
|
@ -56,6 +70,7 @@ def exec(command,verbose=True): |
|
|
system(f"echo {command} >> list_of_exec") |
|
|
system(f"echo {command} >> list_of_exec") |
|
|
return system(command) |
|
|
return system(command) |
|
|
|
|
|
|
|
|
|
|
|
num_ski_positions = 8 |
|
|
def model_to_actual(ski_position): |
|
|
def model_to_actual(ski_position): |
|
|
if ski_position == 1: |
|
|
if ski_position == 1: |
|
|
return 1 |
|
|
return 1 |
|
@ -140,31 +155,42 @@ def optimalAction(choices): |
|
|
return max(choices.items(), key=operator.itemgetter(1))[0] |
|
|
return max(choices.items(), key=operator.itemgetter(1))[0] |
|
|
|
|
|
|
|
|
def computeStateRanking(mdp_file): |
|
|
def computeStateRanking(mdp_file): |
|
|
|
|
|
logger.info("Computing state ranking") |
|
|
|
|
|
tic() |
|
|
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'" |
|
|
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'" |
|
|
exec(command) |
|
|
exec(command) |
|
|
|
|
|
logger.info(f"Computing state ranking - DONE: took {toc()} seconds") |
|
|
|
|
|
|
|
|
def fillStateRanking(file_name, match=""): |
|
|
def fillStateRanking(file_name, match=""): |
|
|
|
|
|
logger.info("Parsing state ranking") |
|
|
|
|
|
tic() |
|
|
state_ranking = dict() |
|
|
state_ranking = dict() |
|
|
try: |
|
|
try: |
|
|
with open(file_name, "r") as f: |
|
|
with open(file_name, "r") as f: |
|
|
file_content = f.readlines() |
|
|
file_content = f.readlines() |
|
|
for line in file_content: |
|
|
for line in file_content: |
|
|
if not "move=0" in line: continue |
|
|
if not "move=0" in line: continue |
|
|
|
|
|
ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:","")) |
|
|
|
|
|
if ranking_value <= 0.1: |
|
|
|
|
|
continue |
|
|
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) |
|
|
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) |
|
|
#print("stateMapping", stateMapping) |
|
|
#print("stateMapping", stateMapping) |
|
|
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) |
|
|
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) |
|
|
choices = {key:float(value) for (key,value) in choices.items()} |
|
|
choices = {key:float(value) for (key,value) in choices.items()} |
|
|
#print("choices", choices) |
|
|
#print("choices", choices) |
|
|
ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:","")) |
|
|
|
|
|
#print("ranking_value", ranking_value) |
|
|
#print("ranking_value", ranking_value) |
|
|
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) |
|
|
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) |
|
|
value = StateValue(ranking_value, choices) |
|
|
value = StateValue(ranking_value, choices) |
|
|
state_ranking[state] = value |
|
|
state_ranking[state] = value |
|
|
|
|
|
logger.info(f"Parsing state ranking - DONE: took {toc()} seconds") |
|
|
return state_ranking |
|
|
return state_ranking |
|
|
|
|
|
|
|
|
except EnvironmentError: |
|
|
except EnvironmentError: |
|
|
print("Ranking file not available. Exiting.") |
|
|
print("Ranking file not available. Exiting.") |
|
|
|
|
|
toc() |
|
|
sys.exit(1) |
|
|
sys.exit(1) |
|
|
|
|
|
except: |
|
|
|
|
|
toc() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fixed_left_states = list() |
|
|
fixed_left_states = list() |
|
@ -226,13 +252,94 @@ exec(f"cp 1_full_scaled_down.png images/testing_{id}/testing_0000.png") |
|
|
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") |
|
|
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") |
|
|
|
|
|
|
|
|
markerSize = 1 |
|
|
markerSize = 1 |
|
|
markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} |
|
|
|
|
|
|
|
|
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} |
|
|
|
|
|
|
|
|
|
|
|
def f(n): |
|
|
|
|
|
if n >= 1.0: |
|
|
|
|
|
return True |
|
|
|
|
|
return False |
|
|
|
|
|
|
|
|
|
|
|
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0): |
|
|
|
|
|
markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)} |
|
|
|
|
|
for state in states: |
|
|
|
|
|
s = state[0] |
|
|
|
|
|
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '" |
|
|
|
|
|
markerList[s.ski_position].append(marker) |
|
|
|
|
|
for pos, marker in markerList.items(): |
|
|
|
|
|
command = f"convert images/testing_{id}/{target_prefix}_{pos:02}.png {' '.join(marker)} images/testing_{id}/{target_prefix}_{pos:02}.png" |
|
|
|
|
|
exec(command, verbose=False) |
|
|
|
|
|
|
|
|
while True: |
|
|
|
|
|
computeStateRanking(f"{init_mdp}_{iteration:03}.prism") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def concatImages(prefix): |
|
|
|
|
|
exec(f"montage images/testing_{id}/{prefix}_*png -geometry +0+0 -tile x1 images/testing_{id}/{prefix}.png", verbose=False) |
|
|
|
|
|
|
|
|
|
|
|
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0): |
|
|
|
|
|
""" |
|
|
|
|
|
Useful to draw a set of states, e.g. a single cluster |
|
|
|
|
|
""" |
|
|
|
|
|
markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()} |
|
|
|
|
|
logger.info(f"Drawing {len(states)} states onto {target}") |
|
|
|
|
|
tic() |
|
|
|
|
|
for state in states: |
|
|
|
|
|
s = state[0] |
|
|
|
|
|
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '" |
|
|
|
|
|
markerList[s.ski_position].append(marker) |
|
|
|
|
|
for pos, marker in markerList.items(): |
|
|
|
|
|
command = f"convert {source} {' '.join(marker)} images/testing_{id}/{target}_{pos:02}.png" |
|
|
|
|
|
exec(command, verbose=False) |
|
|
|
|
|
exec(f"montage images/testing_{id}/{target}_*png -geometry +0+0 -tile x1 images/testing_{id}/{target}.png", verbose=False) |
|
|
|
|
|
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds") |
|
|
|
|
|
|
|
|
|
|
|
def drawClusters(clusterDict, target, alpha_factor=1.0): |
|
|
|
|
|
for ski_position in range(1, num_ski_positions + 1): |
|
|
|
|
|
source = "images/1_full_scaled_down.png" |
|
|
|
|
|
exec(f"cp {source} images/testing_{id}/{target}_{ski_position:02}.png") |
|
|
|
|
|
for _, clusterStates in clusterDict.items(): |
|
|
|
|
|
color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}" |
|
|
|
|
|
drawOntoSkiPosImage(clusterStates, color, f"clusters") |
|
|
|
|
|
concatImages("clusters") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_logger(): |
|
|
|
|
|
logger = logging.getLogger('main') |
|
|
|
|
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
handler = logging.StreamHandler(sys.stdout) |
|
|
|
|
|
formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s') |
|
|
|
|
|
handler.setFormatter(formatter) |
|
|
|
|
|
logger.addHandler(handler) |
|
|
|
|
|
|
|
|
|
|
|
def clusterImportantStates(ranking, n_clusters=10): |
|
|
|
|
|
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster") |
|
|
|
|
|
tic() |
|
|
|
|
|
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking] |
|
|
|
|
|
kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states) |
|
|
|
|
|
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster - DONE: took {toc()} seconds") |
|
|
|
|
|
clusterDict = {i : list() for i in range(0,n_clusters)} |
|
|
|
|
|
for i, state in enumerate(ranking): |
|
|
|
|
|
clusterDict[kmeans.labels_[i]].append(state) |
|
|
|
|
|
drawClusters(clusterDict, f"clusters") |
|
|
|
|
|
return clusterDict |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
|
_init_logger() |
|
|
|
|
|
logger = logging.getLogger('main') |
|
|
|
|
|
logger.info("Starting") |
|
|
|
|
|
while True: |
|
|
|
|
|
#computeStateRanking(f"{init_mdp}_{iteration:03}.prism") |
|
|
ranking = fillStateRanking("action_ranking") |
|
|
ranking = fillStateRanking("action_ranking") |
|
|
sorted_ranking = sorted(ranking.items(), key=lambda x: x[1].ranking) |
|
|
|
|
|
for important_state in sorted_ranking[-100:-1]: |
|
|
|
|
|
|
|
|
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking) |
|
|
|
|
|
print(type(sorted_ranking)) |
|
|
|
|
|
clusters = clusterImportantStates(sorted_ranking) |
|
|
|
|
|
|
|
|
|
|
|
sys.exit(1) |
|
|
|
|
|
#for i, state in enumerate(sorted_ranking): |
|
|
|
|
|
# print(state) |
|
|
|
|
|
# if i % 10 == 0: |
|
|
|
|
|
# input("") |
|
|
|
|
|
#print(len(sorted_ranking)) |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
for important_state in ranking[-100:-1]: |
|
|
optimal_choice = optimalAction(important_state[1].choices) |
|
|
optimal_choice = optimalAction(important_state[1].choices) |
|
|
#print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}") |
|
|
#print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}") |
|
|
x = important_state[0].x |
|
|
x = important_state[0].x |
|
@ -248,4 +355,5 @@ while True: |
|
|
exec(command, verbose=False) |
|
|
exec(command, verbose=False) |
|
|
exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False) |
|
|
exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False) |
|
|
iteration += 1 |
|
|
iteration += 1 |
|
|
|
|
|
""" |
|
|
update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism") |
|
|
update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism") |