You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
359 lines
14 KiB
359 lines
14 KiB
import sys
|
|
import operator
|
|
from os import listdir, system
|
|
import re
|
|
from random import randrange
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action
|
|
#from colors import *
|
|
from PIL import Image
|
|
from matplotlib import pyplot as plt
|
|
import cv2
|
|
import pickle
|
|
import queue
|
|
from dataclasses import dataclass, field
|
|
|
|
from sklearn.cluster import KMeans
|
|
|
|
from enum import Enum
|
|
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
|
|
import logging
|
|
# Module-wide logger; the __main__ block rebinds this name to the configured 'main' logger.
logger = logging.getLogger(name=__name__)
|
|
|
|
#import readchar
|
|
|
|
from sample_factory.algo.utils.tensor_dict import TensorDict
|
|
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
|
|
|
|
import time
|
|
|
|
import os

# Model-checker binary invoked by computeStateRanking and the Skiing ROM loaded into ALE.
# Both can be overridden through the environment (TEMPEST_BINARY / SKIING_ROM); the
# defaults are the original author's machine-specific paths.
tempest_binary = os.environ.get(
    "TEMPEST_BINARY",
    "/home/spranger/projects/tempest-devel/ranking_release/bin/storm",
)
rom_file = os.environ.get(
    "SKIING_ROM",
    "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin",
)
|
|
|
|
def tic():
    """Start (or restart) the global stopwatch that toc() reads."""
    global startTime_for_tictoc
    # perf_counter is monotonic and high-resolution, so measured durations are not
    # skewed by system-clock adjustments the way time.time() differences can be.
    startTime_for_tictoc = time.perf_counter()


def toc():
    """Return the seconds elapsed since the last tic(), or None if tic() was never called."""
    if 'startTime_for_tictoc' in globals():
        return time.perf_counter() - startTime_for_tictoc
    return None
|
|
|
|
class Verdict(Enum):
    """Outcome of one simulated test run (as returned by run_single_test)."""

    # run_single_test returns this when the run ends without a decision.
    INCONCLUSIVE = 1
    # Not produced by run_single_test in this file.
    GOOD = 2
    # run_single_test returns this when the tracked speed RAM value stayed at zero.
    BAD = 3


# Comma-separated RGB triplets used to colour markers per verdict.
verdict_to_color_map = {
    Verdict.BAD: "200,0,0",
    Verdict.INCONCLUSIVE: "40,40,200",
    Verdict.GOOD: "00,200,100",
}
|
|
|
|
def convert(tuples):
    """Build a dict from an iterable of (key, value) pairs; later keys win."""
    mapping = dict(tuples)
    return mapping
|
|
|
|
@dataclass(eq=True, frozen=True)
class State:
    """A discrete model state parsed from the ranking file.

    Frozen, and therefore hashable, so instances can serve as dict keys.
    """

    x: int  # x coordinate (also used as the pixel x when drawing markers)
    y: int  # y coordinate (also used as the pixel y when drawing markers)
    ski_position: int  # ski position index as written in the ranking file


def default_value():
    """Return a fresh placeholder choices dict holding no action and no value."""
    return dict(action=None, choiceValue=None)


@dataclass(eq=True, frozen=True)
class StateValue:
    """A state's ranking together with its per-action choice values."""

    ranking: float
    # Action label -> value; each instance gets its own placeholder dict by default.
    choices: dict = field(default_factory=default_value)
|
|
|
|
def exec(command, verbose=True):
    """Run *command* through the shell and return the os.system() status.

    The command is first appended verbatim, one per line, to the file
    ``list_of_exec`` in the current directory, which makes the log replayable.
    NOTE: this shadows the builtin exec(); the rest of the module calls it by this name.

    Args:
        command: shell command line to execute.
        verbose: when true, print the command before running it.

    Returns:
        The status returned by os.system (0 on success on POSIX).
    """
    if verbose:
        print(f"Executing {command}")
    # Log with plain file I/O. The previous `echo {command} >> list_of_exec` went through
    # the shell, which stripped quotes, expanded globs and would execute anything after
    # `;`/`&&` in the command.
    with open("list_of_exec", "a") as log:
        log.write(f"{command}\n")
    return system(command)
|
|
|
|
num_ski_positions = 8

# (model positions, actual position): model positions 1..14 collapse onto 1..8,
# with only the two extremes (1 and 14) mapping one-to-one.
_MODEL_TO_ACTUAL_GROUPS = (
    ((1,), 1),
    ((2, 3), 2),
    ((4, 5), 3),
    ((6, 7), 4),
    ((8, 9), 5),
    ((10, 11), 6),
    ((12, 13), 7),
    ((14,), 8),
)


def model_to_actual(ski_position):
    """Map a model ski position (1..14) to the actual one (1..8); None if out of range."""
    for model_positions, actual in _MODEL_TO_ACTUAL_GROUPS:
        if ski_position in model_positions:
            return actual
    return None
|
|
|
|
def input_to_action(char):
    """Translate a one-character command string into an ALE action or a keyword.

    "0"/"1"/"2" -> Action.NOOP / Action.RIGHT / Action.LEFT;
    "3"/"4"/"5" -> "reset" / "set_x" / "set_vel";
    "w", "a", "s", "d" are returned unchanged; anything else yields None.
    """
    for key, keyword in (("3", "reset"), ("4", "set_x"), ("5", "set_vel")):
        if char == key:
            return keyword
    if char in ("w", "a", "s", "d"):
        return char
    # Action members are only looked up once a digit actually matches.
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    return None
|
|
|
|
def drawImportantStates(important_states):
    """Draw each important state as a yellow square, one image per ski position 1..14.

    Args:
        important_states: iterable of (State, StateValue) pairs; the marker alpha
            is the state's ranking.

    Writes first_try_01.png .. first_try_14.png (current directory) from
    images/1_full_scaled_down.png via ImageMagick `convert`.
    """
    half_size = 2  # marker half-width in pixels
    commands_by_position = {position: [] for position in range(1, 15)}
    for item in important_states:
        state, value = item[0], item[1]
        commands_by_position[state.ski_position].append(
            f"-fill 'rgba(255,204,0,{value.ranking})' -draw 'rectangle {state.x-half_size},{state.y-half_size} {state.x+half_size},{state.y+half_size} '"
        )
    for position in range(1, 15):
        exec(f"convert images/1_full_scaled_down.png {' '.join(commands_by_position[position])} first_try_{position:02}.png")
|
|
|
|
# Ski position -> (action, repetitions) used to steer the skier into that position.
ski_position_counter = {
    1: (Action.LEFT, 40),
    2: (Action.LEFT, 35),
    3: (Action.LEFT, 30),
    4: (Action.LEFT, 10),
    5: (Action.NOOP, 1),
    6: (Action.RIGHT, 10),
    7: (Action.RIGHT, 30),
    8: (Action.RIGHT, 40),
}


def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=200):
    """Put the emulator into a given state, let the policy drive, and judge the run.

    Args:
        ale: ALE interface to drive.
        nn_wrapper: policy; its .query() receives a TensorDict of the last 4 frames.
        x: value written to RAM address 25 (presumably the horizontal position).
        y: key into the global ramDICT RAM snapshots.
        ski_position: key into ski_position_counter (1..8).
        duration: maximum number of emulator steps.

    Returns:
        (Verdict.BAD, first_action) when the speed value read from RAM[14] is zero for
        the five readings preceding the latest one (after more than 15 steps);
        otherwise (Verdict.INCONCLUSIVE, first_action). first_action is 0 until the
        policy is queried for the first time.
    """
    # Restore the RAM snapshot for this y: element i is written to address i.
    for address, byte in enumerate(ramDICT[y]):
        ale.setRAM(address, byte)

    steer_action, steer_steps = ski_position_counter[ski_position]
    for _ in range(steer_steps):
        ale.act(steer_action)

    ale.setRAM(14, 0)
    ale.setRAM(25, x)
    ale.setRAM(14, 180)

    frames = []
    speeds = []
    first_action = 0
    have_first_action = False
    for _ in range(duration):
        frames.append(cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA))
        if len(frames) < 4:
            # Not enough history for the policy's 4-frame stack yet.
            ale.act(Action.NOOP)
        else:
            observation = TensorDict({"obs": np.array(frames[-4:])})
            chosen = input_to_action(str(nn_wrapper.query(observation)))
            if not have_first_action:
                have_first_action = True
                first_action = chosen
            ale.act(chosen)
        speeds.append(ale.getRAM()[14])
        # NOTE(review): [-6:-1] excludes the most recent reading — confirm that is intended.
        if len(speeds) > 15 and sum(speeds[-6:-1]) == 0:
            return (Verdict.BAD, first_action)
    return (Verdict.INCONCLUSIVE, first_action)
|
|
|
|
def optimalAction(choices):
    """Return the key of *choices* with the largest value (first one wins on ties)."""
    return max(choices, key=choices.get)
|
|
|
|
def computeStateRanking(mdp_file):
    """Run the model checker at ``tempest_binary`` on the PRISM model *mdp_file*.

    Checks the property ``Rmax=? [C <= 1000]`` with --buildchoicelab and
    --buildstateval. Where the tool writes its output is decided by the tool itself
    (presumably the file later read by fillStateRanking — TODO confirm).

    Args:
        mdp_file: path to the PRISM model file.

    Returns:
        The os.system status of the invocation (0 on success); a non-zero status is
        also logged as an error instead of being ignored.
    """
    import shlex

    logger.info("Computing state ranking")
    tic()
    # Quote paths so file names with spaces or shell metacharacters stay one argument.
    command = f"{shlex.quote(tempest_binary)} --prism {shlex.quote(mdp_file)} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'"
    status = exec(command)
    if status != 0:
        logger.error(f"Computing state ranking - FAILED: command exited with status {status}")
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
    return status
|
|
|
|
# Signed decimal number, optionally in scientific notation (e.g. "-0.5", "12", "1e-05").
_NUMBER_PATTERN = r"[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?"
_VALUE_RE = re.compile(r"Value:(" + _NUMBER_PATTERN + r")")
_STATE_VAR_RE = re.compile(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?")
_CHOICE_RE = re.compile(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(" + _NUMBER_PATTERN + r")")


def _parse_ranking_line(line, min_ranking):
    """Parse one ranking-file line into (State, StateValue), or return None to skip it."""
    if "move=0" not in line:
        return None
    value_match = _VALUE_RE.search(line)
    if value_match is None:
        logger.warning(f"Skipping ranking line without a parsable Value: {line.strip()}")
        return None
    ranking_value = float(value_match.group(1))
    if ranking_value <= min_ranking:
        return None
    state_mapping = convert(_STATE_VAR_RE.findall(line))
    # Later duplicate labels overwrite earlier ones, as with dict construction.
    choices = {label: float(number) for label, number in _CHOICE_RE.findall(line)}
    state = State(int(state_mapping["x"]), int(state_mapping["y"]), int(state_mapping["ski_position"]))
    return state, StateValue(ranking_value, choices)


def fillStateRanking(file_name, match="", min_ranking=0.1):
    """Parse a state-ranking file into a ``{State: StateValue}`` dict.

    Only lines containing ``move=0`` whose ``Value:`` exceeds *min_ranking* are kept.
    From each kept line, the ``x``, ``y`` and ``ski_position`` variables
    (``name=int`` pairs) form the State. The choice entries (labels containing
    left/right/noop followed by ``:number``) become StateValue.choices.

    Args:
        file_name: path of the ranking file.
        match: unused; kept for backward compatibility.
        min_ranking: rankings at or below this are skipped (default 0.1).

    Returns:
        dict mapping State -> StateValue.

    Raises:
        SystemExit(1) if the file cannot be read. Any other error is logged and
        re-raised; previously a bare ``except`` swallowed it and returned None.
    """
    logger.info("Parsing state ranking")
    tic()
    state_ranking = dict()
    try:
        with open(file_name, "r") as f:
            for line in f:
                parsed = _parse_ranking_line(line, min_ranking)
                if parsed is not None:
                    state, value = parsed
                    state_ranking[state] = value
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        sys.exit(1)
    except Exception:
        logger.exception("Parsing state ranking - FAILED")
        raise
    logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
    return state_ranking
|
|
|
|
|
|
# States whose action has been pinned; update_prism_file turns them into PRISM formulas.
fixed_left_states = []
fixed_right_states = []
fixed_noop_states = []


def populate_fixed_actions(state, action):
    """Record *state* in the list matching *action* (LEFT, RIGHT or NOOP); ignore others."""
    if action == Action.LEFT:
        fixed_left_states.append(state)
    elif action == Action.RIGHT:
        fixed_right_states.append(state)
    elif action == Action.NOOP:
        fixed_noop_states.append(state)
|
|
|
|
def update_prism_file(old_prism_file, new_prism_file):
    """Copy *old_prism_file* to *new_prism_file*, rewriting the Fixed_* formula lines.

    Each ``formula Fixed_Left|Fixed_Right|Fixed_Noop = ...`` line is replaced by
    ``false`` OR-ed with one ``(x=..&y=..&ski_position=..)`` term per state in
    fixed_left_states / fixed_right_states / fixed_noop_states. A formula whose
    list is empty is left as it was.
    """

    def build_formula(name, states):
        disjuncts = "".join(
            f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
            for state in states
        )
        # No trailing "\n": the `$` anchor leaves the original line's newline in place,
        # so appending one would add a blank line on every rewrite.
        return f"formula {name} = false {disjuncts};"

    with open(old_prism_file, 'r') as file:
        filedata = file.read()

    for name, states in (
        ("Fixed_Left", fixed_left_states),
        ("Fixed_Right", fixed_right_states),
        ("Fixed_Noop", fixed_noop_states),
    ):
        if states:
            formula = build_formula(name, states)
            # Callable replacement: the text is inserted verbatim, with no
            # backslash-escape or group-reference processing by re.sub.
            filedata = re.sub(
                rf"^formula {name} =.*$",
                lambda _match, text=formula: text,
                filedata,
                flags=re.MULTILINE,
            )

    with open(new_prism_file, 'w') as file:
        file.write(filedata)
|
|
|
|
# --- Module-level emulator / policy setup: these statements run on import, in order. ---

# Arcade Learning Environment instance used by the test helpers.
ale = ALEInterface()

#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)

# Load the ROM file (path from rom_file above) into the emulator.
ale.loadROM(rom_file)

# RAM snapshots: run_single_test indexes this with y and writes element i of
# ramDICT[y] to RAM address i.
# NOTE(review): pickle.load can execute arbitrary code — only load a trusted
# all_positions_v2.pickle.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)

# NOTE(review): y_ram_setting and this module-level x are not read by any function in
# this file (run_single_test takes its own x parameter).
y_ram_setting = 60
x = 70

# Policy wrapper; run_single_test calls its .query() to choose actions.
nn_wrapper = SampleFactoryNNQueryWrapper()
|
|
|
|
# --- Per-run bookkeeping: iteration counter, unique run id and output directory. ---
iteration = 0
# NOTE: shadows the builtin id(); the drawing helpers read this global to build
# "images/testing_{id}" paths, so the name is kept.
id = int(time.time())
init_mdp = "velocity"
exec(f"mkdir -p images/testing_{id}")
# The background image lives under images/, which is where every drawing helper reads
# it from; copying "1_full_scaled_down.png" from the working directory fails when only
# images/1_full_scaled_down.png exists.
exec(f"cp images/1_full_scaled_down.png images/testing_{id}/testing_0000.png")
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")

# Half-width in pixels of the square markers drawn by the drawing helpers.
markerSize = 1

#markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
|
|
|
|
def f(n):
    """Return True when *n* is at least 1.0, otherwise False."""
    return True if n >= 1.0 else False
|
|
|
|
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
    """Draw *states* in place onto the existing per-ski-position tiles.

    Each state is drawn as a square of half-width markerSize. The fill is
    ``rgba(color, alpha_factor * ranking)``. It goes onto
    images/testing_<id>/<target_prefix>_<pos>.png, which must already exist.

    Args:
        states: iterable of (State, StateValue) pairs.
        color: "r, g, b" string.
        target_prefix: tile file prefix. NOTE(review): the default "cluster_" yields
            "cluster__01.png" (double underscore), because "_" is added again below.
        alpha_factor: multiplier applied to each state's ranking for the alpha.
    """
    markerList = {ski_position: [] for ski_position in range(1, num_ski_positions + 1)}
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, markers in markerList.items():
        if not markers:
            # Nothing to draw for this position: skip spawning a convert process that
            # would only re-encode the tile without changing it.
            continue
        command = f"convert images/testing_{id}/{target_prefix}_{pos:02}.png {' '.join(markers)} images/testing_{id}/{target_prefix}_{pos:02}.png"
        exec(command, verbose=False)
|
|
|
|
|
|
def concatImages(prefix):
    """Montage images/testing_<id>/<prefix>_*png into one row at images/testing_<id>/<prefix>.png."""
    run_dir = f"images/testing_{id}"
    exec(f"montage {run_dir}/{prefix}_*png -geometry +0+0 -tile x1 {run_dir}/{prefix}.png", verbose=False)
|
|
|
|
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster

    For every ski position, draws the matching states onto a fresh copy of *source*
    and writes images/testing_<id>/<target>_<pos>.png. The tiles are then montaged
    into one row at images/testing_<id>/<target>.png.

    Args:
        states: iterable of (State, StateValue) pairs.
        color: "r, g, b" string; marker alpha is alpha_factor * ranking.
        target: output file prefix.
        source: background image copied for every tile.
        alpha_factor: multiplier applied to each state's ranking for the alpha.
    """
    # Derive the positions from num_ski_positions (as drawOntoSkiPosImage does) rather
    # than a hard-coded 1..8 literal that silently diverges if the constant changes.
    markerList = {ski_position: [] for ski_position in range(1, num_ski_positions + 1)}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, markers in markerList.items():
        # Always write every tile (even without markers) so the montage is complete.
        command = f"convert {source} {' '.join(markers)} images/testing_{id}/{target}_{pos:02}.png"
        exec(command, verbose=False)
    exec(f"montage images/testing_{id}/{target}_*png -geometry +0+0 -tile x1 images/testing_{id}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
|
|
|
|
def drawClusters(clusterDict, target, alpha_factor=1.0):
    """Draw every cluster in its own random colour onto per-ski-position tiles.

    The tiles images/testing_<id>/<target>_<pos>.png are reset from the background
    image. Each cluster is then drawn onto them, and they are montaged into
    images/testing_<id>/<target>.png.

    Previously the tiles were copied to *target* but the drawing and the montage
    always used the hard-coded "clusters" prefix, and *alpha_factor* was ignored.

    Args:
        clusterDict: mapping cluster id -> list of (State, StateValue) pairs.
        target: file prefix for the tiles and the final montage.
        alpha_factor: multiplier applied to each state's ranking for the alpha.
    """
    source = "images/1_full_scaled_down.png"
    for ski_position in range(1, num_ski_positions + 1):
        exec(f"cp {source} images/testing_{id}/{target}_{ski_position:02}.png")
    for clusterStates in clusterDict.values():
        color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
    concatImages(target)
|
|
|
|
|
|
def _init_logger():
|
|
logger = logging.getLogger('main')
|
|
logger.setLevel(logging.INFO)
|
|
handler = logging.StreamHandler(sys.stdout)
|
|
formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
|
|
handler.setFormatter(formatter)
|
|
logger.addHandler(handler)
|
|
|
|
def clusterImportantStates(ranking, n_clusters=10):
    """Cluster ranked states with k-means and draw the clusters.

    Features per state are [x, y, ski_position * 10, ranking]. The clusters are drawn
    via drawClusters with the "clusters" prefix.

    Args:
        ranking: sequence of (State, StateValue) pairs.
        n_clusters: requested number of clusters; clamped to len(ranking), because
            KMeans raises when asked for more clusters than samples.

    Returns:
        dict mapping cluster label (0..n_clusters-1) -> list of (State, StateValue)
        pairs. The dict is empty, and nothing is drawn, when *ranking* is empty.
    """
    n_clusters = min(n_clusters, len(ranking))
    if n_clusters == 0:
        logger.warning("No states to cluster")
        return {}
    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
    tic()
    features = [[item[0].x, item[0].y, item[0].ski_position * 10, item[1].ranking] for item in ranking]
    kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(features)
    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster - DONE: took {toc()} seconds")
    clusterDict = {label: [] for label in range(n_clusters)}
    for label, state in zip(kmeans.labels_, ranking):
        clusterDict[label].append(state)
    drawClusters(clusterDict, "clusters")
    return clusterDict
|
|
|
|
if __name__ == '__main__':
    _init_logger()
    # This assignment is at module level, so it rebinds the global `logger` used by the
    # helper functions; they then log through the configured 'main' logger.
    logger = logging.getLogger('main')
    logger.info("Starting")
    while True:
        # Ranking computation is disabled; the ranking is read from an existing
        # "action_ranking" file instead.
        #computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
        ranking = fillStateRanking("action_ranking")
        # (State, StateValue) pairs with ranking > 0.1, in ascending order of ranking.
        sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
        print(type(sorted_ranking))
        clusters = clusterImportantStates(sorted_ranking)

        # NOTE(review): the script stops after the first clustering pass, with exit
        # status 1; everything below in this loop is unreachable. Confirm whether a
        # non-zero status is intended.
        sys.exit(1)
        #for i, state in enumerate(sorted_ranking):
        #    print(state)
        #    if i % 10 == 0:
        #        input("")
        #print(len(sorted_ranking))

        # Disabled testing loop, kept as a string literal. NOTE(review): it would fail if
        # re-enabled as written:
        # - `ranking` is a dict here, so `ranking[-100:-1]` raises (sorted_ranking was
        #   presumably meant);
        # - `markerList` only exists in a commented-out module-level line;
        # - `{verdict_to_color_map[result[0]],0.7}` formats a tuple, not "r,g,b,0.7".
        """
for important_state in ranking[-100:-1]:
optimal_choice = optimalAction(important_state[1].choices)
#print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}")
x = important_state[0].x
y = important_state[0].y
ski_pos = model_to_actual(important_state[0].ski_position)
result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
#print(f".... {result}")
marker = f"-fill 'rgba({verdict_to_color_map[result[0]],0.7})' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '"
markerList[ski_pos].append(marker)
populate_fixed_actions(important_state[0], result[1])
for pos, marker in markerList.items():
command = f"convert images/testing_{id}/testing_0000.png {' '.join(marker)} images/testing_{id}/testing_{iteration+1:03}_{pos:02}.png"
exec(command, verbose=False)
exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False)
iteration += 1
"""
        # Unreachable (after sys.exit above). With iteration == 0 this would read
        # "velocity_-01.prism".
        update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")
|