You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

359 lines
14 KiB

  1. import sys
  2. import operator
  3. from os import listdir, system
  4. import re
  5. from random import randrange
  6. from ale_py import ALEInterface, SDL_SUPPORT, Action
  7. #from colors import *
  8. from PIL import Image
  9. from matplotlib import pyplot as plt
  10. import cv2
  11. import pickle
  12. import queue
  13. from dataclasses import dataclass, field
  14. from sklearn.cluster import KMeans
  15. from enum import Enum
  16. from copy import deepcopy
  17. import numpy as np
  18. import logging
  19. logger = logging.getLogger(__name__)
  20. #import readchar
  21. from sample_factory.algo.utils.tensor_dict import TensorDict
  22. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  23. import time
# Path of the model-checker executable that computeStateRanking invokes.
# NOTE(review): machine-specific absolute path; adjust for other checkouts.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
# Path of the ROM handed to ale.loadROM() further down in this script.
# NOTE(review): machine-specific absolute path; adjust for other environments.
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  26. def tic():
  27. import time
  28. global startTime_for_tictoc
  29. startTime_for_tictoc = time.time()
  30. def toc():
  31. import time
  32. if 'startTime_for_tictoc' in globals():
  33. return time.time() - startTime_for_tictoc
  34. class Verdict(Enum):
  35. INCONCLUSIVE = 1
  36. GOOD = 2
  37. BAD = 3
  38. verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  39. def convert(tuples):
  40. return dict(tuples)
  41. @dataclass(frozen=True)
  42. class State:
  43. x: int
  44. y: int
  45. ski_position: int
  46. def default_value():
  47. return {'action' : None, 'choiceValue' : None}
  48. @dataclass(frozen=True)
  49. class StateValue:
  50. ranking: float
  51. choices: dict = field(default_factory=default_value)
  52. def exec(command,verbose=True):
  53. if verbose: print(f"Executing {command}")
  54. system(f"echo {command} >> list_of_exec")
  55. return system(command)
# Number of per-ski-position image tiles; drawOntoSkiPosImage and drawClusters
# iterate over positions 1..num_ski_positions.
num_ski_positions = 8
  57. def model_to_actual(ski_position):
  58. if ski_position == 1:
  59. return 1
  60. elif ski_position in [2,3]:
  61. return 2
  62. elif ski_position in [4,5]:
  63. return 3
  64. elif ski_position in [6,7]:
  65. return 4
  66. elif ski_position in [8,9]:
  67. return 5
  68. elif ski_position in [10,11]:
  69. return 6
  70. elif ski_position in [12,13]:
  71. return 7
  72. elif ski_position == 14:
  73. return 8
  74. def input_to_action(char):
  75. if char == "0":
  76. return Action.NOOP
  77. if char == "1":
  78. return Action.RIGHT
  79. if char == "2":
  80. return Action.LEFT
  81. if char == "3":
  82. return "reset"
  83. if char == "4":
  84. return "set_x"
  85. if char == "5":
  86. return "set_vel"
  87. if char in ["w", "a", "s", "d"]:
  88. return char
  89. def drawImportantStates(important_states):
  90. draw_commands = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list(), 9:list(), 10:list(), 11:list(), 12:list(), 13:list(), 14:list()}
  91. for state in important_states:
  92. x = state[0].x
  93. y = state[0].y
  94. markerSize = 2
  95. ski_position = state[0].ski_position
  96. draw_commands[ski_position].append(f"-fill 'rgba(255,204,0,{state[1].ranking})' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '")
  97. for i in range(1,15):
  98. command = f"convert images/1_full_scaled_down.png {' '.join(draw_commands[i])} first_try_{i:02}.png"
  99. exec(command)
# Per ski position (1..8): (action to repeat, number of repetitions).
# run_single_test replays this action sequence before starting a test run.
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=200):
    """Start the emulator from a prepared position and let the network play.

    Args:
        ale: emulator interface; its RAM is overwritten by this function.
        nn_wrapper: object whose ``query`` method maps a stack of the last
            four 84x84 grayscale frames to an action.
        x: value written into RAM cell 25.
        y: key into the module-level ``ramDICT``; the stored RAM snapshot is
            copied into the emulator byte by byte.
        ski_position: key into ``ski_position_counter`` (1..8).
        duration: maximum number of emulator steps to run.

    Returns:
        ``(Verdict.BAD, first_action)`` as soon as more than 15 samples of
        RAM cell 14 were recorded and the five samples preceding the newest
        one sum to zero; otherwise ``(Verdict.INCONCLUSIVE, first_action)``
        after ``duration`` steps. ``first_action`` is the first action chosen
        by the network (0 if it was never queried). ``Verdict.GOOD`` is never
        returned here.
    """
    #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
    # Restore the RAM snapshot stored for this y.
    for i, r in enumerate(ramDICT[y]):
        ale.setRAM(i,r)
    ski_position_setting = ski_position_counter[ski_position]
    # Replay the configured action to bring the skis into the requested position.
    # NOTE(review): the nesting of the two setRAM calls below was reconstructed
    # from a listing without indentation - confirm they belong inside this loop.
    for i in range(0,ski_position_setting[1]):
        ale.act(ski_position_setting[0])
        ale.setRAM(14,0)
        ale.setRAM(25,x)
    ale.setRAM(14,180)
    all_obs = list()
    speed_list = list()  # samples of RAM cell 14, one per step
    first_action_set = False
    first_action = 0
    for i in range(0,duration):
        resized_obs = cv2.resize(ale.getScreenGrayscale() , (84,84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)
        if len(all_obs) >= 4:
            # Query the network with the four most recent frames.
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            if not first_action_set:
                first_action_set = True
                first_action = input_to_action(str(action))
            ale.act(input_to_action(str(action)))
        else:
            # Not enough frames for a stack yet: idle.
            ale.act(Action.NOOP)
        speed_list.append(ale.getRAM()[14])
        # NOTE(review): the slice [-6:-1] excludes the newest sample - confirm intended.
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            return (Verdict.BAD, first_action)
        #time.sleep(0.005)
    return (Verdict.INCONCLUSIVE, first_action)
  132. def optimalAction(choices):
  133. return max(choices.items(), key=operator.itemgetter(1))[0]
  134. def computeStateRanking(mdp_file):
  135. logger.info("Computing state ranking")
  136. tic()
  137. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'"
  138. exec(command)
  139. logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
  140. def fillStateRanking(file_name, match=""):
  141. logger.info("Parsing state ranking")
  142. tic()
  143. state_ranking = dict()
  144. try:
  145. with open(file_name, "r") as f:
  146. file_content = f.readlines()
  147. for line in file_content:
  148. if not "move=0" in line: continue
  149. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  150. if ranking_value <= 0.1:
  151. continue
  152. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  153. #print("stateMapping", stateMapping)
  154. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  155. choices = {key:float(value) for (key,value) in choices.items()}
  156. #print("choices", choices)
  157. #print("ranking_value", ranking_value)
  158. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
  159. value = StateValue(ranking_value, choices)
  160. state_ranking[state] = value
  161. logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
  162. return state_ranking
  163. except EnvironmentError:
  164. print("Ranking file not available. Exiting.")
  165. toc()
  166. sys.exit(1)
  167. except:
  168. toc()
# Module-level accumulators filled by populate_fixed_actions and read by
# update_prism_file: one list of states per action.
fixed_left_states = list()
fixed_right_states = list()
fixed_noop_states = list()
  172. def populate_fixed_actions(state, action):
  173. if action == Action.LEFT:
  174. fixed_left_states.append(state)
  175. if action == Action.RIGHT:
  176. fixed_right_states.append(state)
  177. if action == Action.NOOP:
  178. fixed_noop_states.append(state)
  179. def update_prism_file(old_prism_file, new_prism_file):
  180. fixed_left_formula = "formula Fixed_Left = false "
  181. fixed_right_formula = "formula Fixed_Right = false "
  182. fixed_noop_formula = "formula Fixed_Noop = false "
  183. for state in fixed_left_states:
  184. fixed_left_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
  185. for state in fixed_right_states:
  186. fixed_right_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
  187. for state in fixed_noop_states:
  188. fixed_noop_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
  189. fixed_left_formula += ";\n"
  190. fixed_right_formula += ";\n"
  191. fixed_noop_formula += ";\n"
  192. with open(f'{old_prism_file}', 'r') as file :
  193. filedata = file.read()
  194. if len(fixed_left_states) > 0: filedata = re.sub(r"^formula Fixed_Left =.*$", fixed_left_formula, filedata, flags=re.MULTILINE)
  195. if len(fixed_right_states) > 0: filedata = re.sub(r"^formula Fixed_Right =.*$", fixed_right_formula, filedata, flags=re.MULTILINE)
  196. if len(fixed_noop_states) > 0: filedata = re.sub(r"^formula Fixed_Noop =.*$", fixed_noop_formula, filedata, flags=re.MULTILINE)
  197. with open(f'{new_prism_file}', 'w') as file:
  198. file.write(filedata)
# --- Module-level setup: everything below runs on import -------------------
# Emulator instance used by the test runs in this script.
ale = ALEInterface()
#if SDL_SUPPORT:
# ale.setBool("sound", True)
# ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# RAM snapshots, indexed by y in run_single_test.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70
nn_wrapper = SampleFactoryNNQueryWrapper()
iteration = 0
# Timestamp used to name this run's image directory.
# NOTE(review): shadows the builtin id(); other functions in this file read it.
id = int(time.time())
init_mdp = "velocity"
# Prepare the per-run image directory and the iteration-0 copies.
exec(f"mkdir -p images/testing_{id}")
exec(f"cp 1_full_scaled_down.png images/testing_{id}/testing_0000.png")
exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
# Half edge length (in pixels) of the rectangles drawn by the draw* helpers.
markerSize = 1
  217. #markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
  218. def f(n):
  219. if n >= 1.0:
  220. return True
  221. return False
  222. def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
  223. markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
  224. for state in states:
  225. s = state[0]
  226. marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  227. markerList[s.ski_position].append(marker)
  228. for pos, marker in markerList.items():
  229. command = f"convert images/testing_{id}/{target_prefix}_{pos:02}.png {' '.join(marker)} images/testing_{id}/{target_prefix}_{pos:02}.png"
  230. exec(command, verbose=False)
  231. def concatImages(prefix):
  232. exec(f"montage images/testing_{id}/{prefix}_*png -geometry +0+0 -tile x1 images/testing_{id}/{prefix}.png", verbose=False)
  233. def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
  234. """
  235. Useful to draw a set of states, e.g. a single cluster
  236. """
  237. markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
  238. logger.info(f"Drawing {len(states)} states onto {target}")
  239. tic()
  240. for state in states:
  241. s = state[0]
  242. marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  243. markerList[s.ski_position].append(marker)
  244. for pos, marker in markerList.items():
  245. command = f"convert {source} {' '.join(marker)} images/testing_{id}/{target}_{pos:02}.png"
  246. exec(command, verbose=False)
  247. exec(f"montage images/testing_{id}/{target}_*png -geometry +0+0 -tile x1 images/testing_{id}/{target}.png", verbose=False)
  248. logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
  249. def drawClusters(clusterDict, target, alpha_factor=1.0):
  250. for ski_position in range(1, num_ski_positions + 1):
  251. source = "images/1_full_scaled_down.png"
  252. exec(f"cp {source} images/testing_{id}/{target}_{ski_position:02}.png")
  253. for _, clusterStates in clusterDict.items():
  254. color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
  255. drawOntoSkiPosImage(clusterStates, color, f"clusters")
  256. concatImages("clusters")
  257. def _init_logger():
  258. logger = logging.getLogger('main')
  259. logger.setLevel(logging.INFO)
  260. handler = logging.StreamHandler(sys.stdout)
  261. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  262. handler.setFormatter(formatter)
  263. logger.addHandler(handler)
  264. def clusterImportantStates(ranking, n_clusters=10):
  265. logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster")
  266. tic()
  267. states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking]
  268. kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
  269. logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} cluster - DONE: took {toc()} seconds")
  270. clusterDict = {i : list() for i in range(0,n_clusters)}
  271. for i, state in enumerate(ranking):
  272. clusterDict[kmeans.labels_[i]].append(state)
  273. drawClusters(clusterDict, f"clusters")
  274. return clusterDict
# Script entry point: parse the ranking file, cluster the important states,
# draw them, then stop.
if __name__ == '__main__':
    _init_logger()
    # Rebind the module-level logger so the helper functions log via 'main'.
    logger = logging.getLogger('main')
    logger.info("Starting")
    while True:
        #computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
        ranking = fillStateRanking("action_ranking")
        # Keep states ranked above 0.1, ordered by ascending ranking.
        sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
        print(type(sorted_ranking))
        clusters = clusterImportantStates(sorted_ranking)
        # NOTE(review): the script terminates here (with exit status 1) after
        # the first pass, so everything below in this loop is unreachable.
        sys.exit(1)
        #for i, state in enumerate(sorted_ranking):
        # print(state)
        # if i % 10 == 0:
        # input("")
        #print(len(sorted_ranking))
        # The string literal below is disabled code kept for reference; it is
        # an expression statement with no effect.
        """
        for important_state in ranking[-100:-1]:
            optimal_choice = optimalAction(important_state[1].choices)
            #print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}")
            x = important_state[0].x
            y = important_state[0].y
            ski_pos = model_to_actual(important_state[0].ski_position)
            result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
            #print(f".... {result}")
            marker = f"-fill 'rgba({verdict_to_color_map[result[0]],0.7})' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '"
            markerList[ski_pos].append(marker)
            populate_fixed_actions(important_state[0], result[1])
        for pos, marker in markerList.items():
            command = f"convert images/testing_{id}/testing_0000.png {' '.join(marker)} images/testing_{id}/testing_{iteration+1:03}_{pos:02}.png"
            exec(command, verbose=False)
        exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False)
        iteration += 1
        """
        update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")