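# Iterative, verification-guided testing of a Skiing (Atari/ALE) policy
# (summary inferred from the main loop below): a PRISM model is extended with
# Safe/Unsafe cluster formulas, tempest/storm ranks the states, highly ranked
# states are clustered and sampled, and each sample is replayed in the emulator
# to decide whether its cluster is marked safe or unsafe for the next iteration.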
import sys
import operator
from os import listdir, system
import re
from collections import defaultdict
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue
from dataclasses import dataclass, field
from sklearn.cluster import KMeans, DBSCAN
from enum import Enum
from copy import deepcopy
import numpy as np
import logging
logger = logging.getLogger(__name__)
#import readchar

from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper

import time

tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
def tic():
    global startTime_for_tictoc
    startTime_for_tictoc = time.time()

def toc():
    if 'startTime_for_tictoc' in globals():
        return time.time() - startTime_for_tictoc
class Verdict(Enum):
    INCONCLUSIVE = 1
    GOOD = 2
    BAD = 3

verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}

def convert(tuples):
    return dict(tuples)
@dataclass(frozen=True)
class State:
    x: int
    y: int
    ski_position: int
    #velocity: int

def default_value():
    return {'action': None, 'choiceValue': None}

@dataclass(frozen=True)
class StateValue:
    ranking: float
    choices: dict = field(default_factory=default_value)
# Note: shadows the builtin exec; runs a shell command via os.system and logs it to list_of_exec.
def exec(command, verbose=True):
    if verbose: print(f"Executing {command}")
    system(f"echo {command} >> list_of_exec")
    return system(command)

num_tests_per_cluster = 50
factor_tests_per_cluster = 0.2
num_ski_positions = 8
def input_to_action(char):
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    if char == "3":
        return "reset"
    if char == "4":
        return "set_x"
    if char == "5":
        return "set_vel"
    if char in ["w", "a", "s", "d"]:
        return char
def saveObservations(observations, verdict, testDir):
    testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
    if len(observations) < 20:
        logger.warning(f"Potentially spurious test case for {testDir}")
        testDir = f"{testDir}_pot_spurious"
    exec(f"mkdir {testDir}", verbose=False)
    for i, obs in enumerate(observations):
        img = Image.fromarray(obs)
        img.save(f"{testDir}/{i:003}.png")
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40)}
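# A single test: restore the RAM snapshot for row y from ramDICT, steer to the
# requested ski position, place the skier at column x, then let the policy act for
# `duration` frames on a stack of the last four 84x84 grayscale observations.
# A window of zero speed readings (RAM byte 14), presumably a crash, yields Verdict.BAD.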
#def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
def run_single_test(ale, nn_wrapper, x, y, ski_position, duration=50):
    #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
    testDir = f"{x}_{y}_{ski_position}"  #_{velocity}"
    for i, r in enumerate(ramDICT[y]):
        ale.setRAM(i, r)
    ski_position_setting = ski_position_counter[ski_position]
    for i in range(0, ski_position_setting[1]):
        ale.act(ski_position_setting[0])
    # RAM[25] holds the skier's x position; RAM[14] is read back below as the speed value
    ale.setRAM(14, 0)
    ale.setRAM(25, x)
    ale.setRAM(14, 180)  # TODO
    all_obs = list()
    speed_list = list()
    first_action_set = False
    first_action = 0
    for i in range(0, duration):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)  # keep the observation history for the policy's frame stack
        if len(all_obs) >= 4:
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            if not first_action_set:
                first_action_set = True
                first_action = input_to_action(str(action))
            ale.act(input_to_action(str(action)))
        else:
            ale.act(Action.NOOP)
        speed_list.append(ale.getRAM()[14])
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            saveObservations(all_obs, Verdict.BAD, testDir)
            return Verdict.BAD
    saveObservations(all_obs, Verdict.GOOD, testDir)
    return Verdict.GOOD
def optimalAction(choices):
    return max(choices.items(), key=operator.itemgetter(1))[0]
def computeStateRanking(mdp_file):
    logger.info("Computing state ranking")
    tic()
    command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop 'Rmax=? [C <= 1000]'"
    exec(command)
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
def fillStateRanking(file_name, match=""):
    logger.info("Parsing state ranking")
    tic()
    state_ranking = dict()
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
        for line in file_content:
            if not "move=0" in line: continue
            ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:", ""))
            if ranking_value <= 0.1:
                continue
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
            #print("stateMapping", stateMapping)
            choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
            choices = {key: float(value) for (key, value) in choices.items()}
            #print("choices", choices)
            #print("ranking_value", ranking_value)
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))  #, int(stateMapping["velocity"]))
            value = StateValue(ranking_value, choices)
            state_ranking[state] = value
        logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
        return state_ranking
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(1)
    except Exception:
        toc()
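# PRISM formula construction: each cluster becomes one formula that enumerates its
# states as a disjunction over ski_position and y, with consecutive x values merged
# into intervals; createSafeFormula/createUnsafeFormula combine the per-cluster
# formulas into the "Safe"/"Unsafe" labels appended to the next model file.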
def createDisjunction(formulas):
    return " | ".join(formulas)
def clusterFormula(cluster):
    formula = ""
    #states = [(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]
    states = [(s[0].x, s[0].y, s[0].ski_position) for s in cluster]
    # group states by ski position, then by y, and collapse consecutive x values into ranges
    skiPositionGroup = defaultdict(list)
    for item in states:
        skiPositionGroup[item[2]].append(item)
    first_ski_position = True
    #todo add velocity here
    for skiPosition, group in skiPositionGroup.items():
        if first_ski_position:
            first_ski_position = False
        else:
            formula += " | "
        formula += f"ski_position={skiPosition} & ("
        yPosGroup = defaultdict(list)
        for item in group:
            yPosGroup[item[1]].append(item)
        first_y = True
        for y, y_group in yPosGroup.items():
            if first_y:
                first_y = False
            else:
                formula += " | "
            sorted_y_group = sorted(y_group, key=lambda s: s[0])
            formula += f"( y={y} & ("
            current_x_min = sorted_y_group[0][0]
            current_x = sorted_y_group[0][0]
            x_ranges = list()
            for state in sorted_y_group[1:]:
                if state[0] - current_x == 1:
                    current_x = state[0]
                else:
                    x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
                    current_x_min = state[0]
                    current_x = state[0]
            x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
            formula += " | ".join(x_ranges)
            formula += ") )"
        formula += ")"
    return formula
def createUnsafeFormula(clusters):
    formulas = ""
    disjunction = "formula Unsafe = false"
    for i, cluster in enumerate(clusters):
        formulas += f"formula Unsafe_{i} = {clusterFormula(cluster)};\n"
        disjunction += f" | Unsafe_{i}"
    disjunction += ";\n"
    label = "label \"Unsafe\" = Unsafe;\n"
    return formulas + "\n" + disjunction + label
def createSafeFormula(clusters):
    formulas = ""
    disjunction = "formula Safe = false"
    for i, cluster in enumerate(clusters):
        formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n"
        disjunction += f" | Safe_{i}"
    disjunction += ";\n"
    label = "label \"Safe\" = Safe;\n"
    return formulas + "\n" + disjunction + label
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    logger.info("Creating next prism file")
    tic()
    initFile = f"{newFile}_no_formulas.prism"
    newFile = f"{newFile}_{iteration:03}.prism"
    exec(f"cp {initFile} {newFile}", verbose=False)
    with open(newFile, "a") as prism:
        prism.write(createSafeFormula(safeStates))
        prism.write(createUnsafeFormula(unsafeStates))
    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
ale = ALEInterface()
#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)

# Load the ROM file
ale.loadROM(rom_file)

with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70

nn_wrapper = SampleFactoryNNQueryWrapper()

experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)

markerSize = 1
imagesDir = f"images/testing_{experiment_id}"
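# Drawing helpers: ImageMagick's `convert` and `montage` are invoked through exec()
# to mark states on per-ski-position copies of the scaled-down track image and to
# tile those copies into one overview image per iteration (optionally shown via sxiv).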
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
    markerList = {ski_position: list() for ski_position in range(1, num_ski_positions + 1)}
    for state in states:
        s = state[0]
        #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {imagesDir}/{target_prefix}_{pos:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_individual.png"
        exec(command, verbose=False)
def concatImages(prefix, iteration):
    exec(f"montage {imagesDir}/{prefix}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
    exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster
    """
    markerList = {1: list(), 2: list(), 3: list(), 4: list(), 5: list(), 6: list(), 7: list(), 8: list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
    for ski_position in range(1, num_ski_positions + 1):
        source = "images/1_full_scaled_down.png"
        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_individual.png", verbose=False)
    for _, clusterStates in clusterDict.items():
        color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
    concatImages(target, iteration)
def drawResult(clusterDict, target, iteration):
    for ski_position in range(1, num_ski_positions + 1):
        source = "images/1_full_scaled_down.png"
        exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_individual.png")
    for _, (clusterStates, result) in clusterDict.items():
        color = "100,100,100"
        if result == Verdict.GOOD:
            color = "0,200,0"
        elif result == Verdict.BAD:
            color = "200,0,0"
        drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
    concatImages(target, iteration)
def _init_logger():
    logger = logging.getLogger('main')
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter('[%(levelname)s] %(module)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
def clusterImportantStates(ranking, iteration, n_clusters=40):
    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} clusters")
    tic()
    #states = [[s[0].x,s[0].y, s[0].ski_position * 10, s[0].velocity * 10, s[1].ranking] for s in ranking]
    states = [[s[0].x, s[0].y, s[0].ski_position * 10, s[1].ranking] for s in ranking]
    kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
    #dbscan = DBSCAN().fit(states)
    logger.info(f"Starting to cluster {len(ranking)} states into {n_clusters} clusters - DONE: took {toc()} seconds")
    clusterDict = {i: list() for i in range(0, n_clusters)}
    for i, state in enumerate(ranking):
        clusterDict[kmeans.labels_[i]].append(state)
    drawClusters(clusterDict, "clusters", iteration)
    return clusterDict
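# Main refinement loop: write the next PRISM file with the current Safe/Unsafe
# formulas, rank states with tempest, cluster the important states, test a sample
# of each cluster in the emulator, and record whole clusters as safe or unsafe
# for the next iteration.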
if __name__ == '__main__':
    _init_logger()
    logger = logging.getLogger('main')
    logger.info("Starting")
    n_clusters = 40
    testAll = False

    safeStates = list()
    unsafeStates = list()
    iteration = 0
    while True:
        updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
        computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
        ranking = fillStateRanking("action_ranking")
        sorted_ranking = sorted((x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
        clusters = clusterImportantStates(sorted_ranking, iteration, n_clusters)

        if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
        clusterResult = dict()
        for id, cluster in clusters.items():
            num_tests = int(factor_tests_per_cluster * len(cluster))
            logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
            randomStates = np.random.choice(len(cluster), num_tests, replace=False)
            randomStates = [cluster[i] for i in randomStates]

            verdictGood = True
            for state in randomStates:
                x = state[0].x
                y = state[0].y
                ski_pos = state[0].ski_position
                #velocity = state[0].velocity
                result = run_single_test(ale, nn_wrapper, x, y, ski_pos, duration=50)
                if result == Verdict.BAD:
                    if testAll:
                        failingPerCluster[id].append(state)
                    else:
                        clusterResult[id] = (cluster, Verdict.BAD)
                        verdictGood = False
                        unsafeStates.append(cluster)
                        break
            if verdictGood:
                clusterResult[id] = (cluster, Verdict.GOOD)
                safeStates.append(cluster)
        if testAll: drawClusters(failingPerCluster, "failing", iteration)
        drawResult(clusterResult, "result", iteration)
        iteration += 1