Browse Source

started working on clustering via kmeans

add_velocity_into_framework
sp 8 months ago
parent
commit
e7b5f8344b
  1. 126
      rom_evaluate.py

126
rom_evaluate.py

@ -1,9 +1,10 @@
import sys import sys
import operator import operator
from os import listdir, system from os import listdir, system
import re
from random import randrange from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action from ale_py import ALEInterface, SDL_SUPPORT, Action
from colors import *
#from colors import *
from PIL import Image from PIL import Image
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import cv2 import cv2
@ -11,13 +12,18 @@ import pickle
import queue import queue
from dataclasses import dataclass, field from dataclasses import dataclass, field
from sklearn.cluster import KMeans
from enum import Enum from enum import Enum
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
import readchar
import logging
logger = logging.getLogger(__name__)
#import readchar
from sample_factory.algo.utils.tensor_dict import TensorDict from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
@ -25,9 +31,17 @@ from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
import time import time
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm" tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
def tic():
import time
global startTime_for_tictoc
startTime_for_tictoc = time.time()
def toc():
import time
if 'startTime_for_tictoc' in globals():
return time.time() - startTime_for_tictoc
class Verdict(Enum): class Verdict(Enum):
INCONCLUSIVE = 1 INCONCLUSIVE = 1
@ -56,6 +70,7 @@ def exec(command,verbose=True):
system(f"echo {command} >> list_of_exec") system(f"echo {command} >> list_of_exec")
return system(command) return system(command)
num_ski_positions = 8
def model_to_actual(ski_position): def model_to_actual(ski_position):
if ski_position == 1: if ski_position == 1:
return 1 return 1
@ -140,31 +155,42 @@ def optimalAction(choices):
return max(choices.items(), key=operator.itemgetter(1))[0] return max(choices.items(), key=operator.itemgetter(1))[0]
def computeStateRanking(mdp_file): def computeStateRanking(mdp_file):
logger.info("Computing state ranking")
tic()
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'" command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'"
exec(command) exec(command)
logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
def fillStateRanking(file_name, match=""): def fillStateRanking(file_name, match=""):
logger.info("Parsing state ranking")
tic()
state_ranking = dict() state_ranking = dict()
try: try:
with open(file_name, "r") as f: with open(file_name, "r") as f:
file_content = f.readlines() file_content = f.readlines()
for line in file_content: for line in file_content:
if not "move=0" in line: continue if not "move=0" in line: continue
ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
if ranking_value <= 0.1:
continue
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
#print("stateMapping", stateMapping) #print("stateMapping", stateMapping)
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
choices = {key:float(value) for (key,value) in choices.items()} choices = {key:float(value) for (key,value) in choices.items()}
#print("choices", choices) #print("choices", choices)
ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
#print("ranking_value", ranking_value) #print("ranking_value", ranking_value)
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
value = StateValue(ranking_value, choices) value = StateValue(ranking_value, choices)
state_ranking[state] = value state_ranking[state] = value
logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
return state_ranking return state_ranking
except EnvironmentError: except EnvironmentError:
print("Ranking file not available. Exiting.") print("Ranking file not available. Exiting.")
toc()
sys.exit(1) sys.exit(1)
except:
toc()
fixed_left_states = list() fixed_left_states = list()
@ -226,13 +252,94 @@ exec(f"cp 1_full_scaled_down.png images/testing_{id}/testing_0000.png")
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism") exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
markerSize = 1 markerSize = 1
markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
def f(n):
if n >= 1.0:
return True
return False
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
for state in states:
s = state[0]
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
markerList[s.ski_position].append(marker)
for pos, marker in markerList.items():
command = f"convert images/testing_{id}/{target_prefix}_{pos:02}.png {' '.join(marker)} images/testing_{id}/{target_prefix}_{pos:02}.png"
exec(command, verbose=False)
while True:
computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
def concatImages(prefix):
exec(f"montage images/testing_{id}/{prefix}_*png -geometry +0+0 -tile x1 images/testing_{id}/{prefix}.png", verbose=False)
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
"""
Useful to draw a set of states, e.g. a single cluster
"""
markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
logger.info(f"Drawing {len(states)} states onto {target}")
tic()
for state in states:
s = state[0]
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
markerList[s.ski_position].append(marker)
for pos, marker in markerList.items():
command = f"convert {source} {' '.join(marker)} images/testing_{id}/{target}_{pos:02}.png"
exec(command, verbose=False)
exec(f"montage images/testing_{id}/{target}_*png -geometry +0+0 -tile x1 images/testing_{id}/{target}.png", verbose=False)
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
def drawClusters(clusterDict, target, alpha_factor=1.0):
for ski_position in range(1, num_ski_positions + 1):
source = "images/1_full_scaled_down.png"
exec(f"cp {source} images/testing_{id}/{target}_{ski_position:02}.png")
for _, clusterStates in clusterDict.items():
color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
drawOntoSkiPosImage(clusterStates, color, f"clusters")
concatImages("clusters")
def _init_logger():
logger = logging.getLogger('main')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
def clusterImportantStates(ranking, n_clusters=10):
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
tic()
states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking]
kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster - DONE: took {toc()} seconds")
clusterDict = {i : list() for i in range(0,n_clusters)}
for i, state in enumerate(ranking):
clusterDict[kmeans.labels_[i]].append(state)
drawClusters(clusterDict, f"clusters")
return clusterDict
if __name__ == '__main__':
_init_logger()
logger = logging.getLogger('main')
logger.info("Starting")
while True:
#computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
ranking = fillStateRanking("action_ranking") ranking = fillStateRanking("action_ranking")
sorted_ranking = sorted(ranking.items(), key=lambda x: x[1].ranking)
for important_state in sorted_ranking[-100:-1]:
sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
print(type(sorted_ranking))
clusters = clusterImportantStates(sorted_ranking)
sys.exit(1)
#for i, state in enumerate(sorted_ranking):
# print(state)
# if i % 10 == 0:
# input("")
#print(len(sorted_ranking))
"""
for important_state in ranking[-100:-1]:
optimal_choice = optimalAction(important_state[1].choices) optimal_choice = optimalAction(important_state[1].choices)
#print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}") #print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}")
x = important_state[0].x x = important_state[0].x
@ -248,4 +355,5 @@ while True:
exec(command, verbose=False) exec(command, verbose=False)
exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False) exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False)
iteration += 1 iteration += 1
"""
update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism") update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")
Loading…
Cancel
Save