You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

488 lines
19 KiB

6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
  1. import sys
  2. import operator
  3. from os import listdir, system
  4. import subprocess
  5. import re
  6. from collections import defaultdict
  7. from random import randrange
  8. from ale_py import ALEInterface, SDL_SUPPORT, Action
  9. from PIL import Image
  10. from matplotlib import pyplot as plt
  11. import cv2
  12. import pickle
  13. import queue
  14. from dataclasses import dataclass, field
  15. from sklearn.cluster import KMeans, DBSCAN
  16. from enum import Enum
  17. from copy import deepcopy
  18. import numpy as np
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. #import readchar
  22. from sample_factory.algo.utils.tensor_dict import TensorDict
  23. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  24. import time
  25. tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
  26. rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  27. def tic():
  28. import time
  29. global startTime_for_tictoc
  30. startTime_for_tictoc = time.time()
  31. def toc():
  32. import time
  33. if 'startTime_for_tictoc' in globals():
  34. return time.time() - startTime_for_tictoc
  35. class Verdict(Enum):
  36. INCONCLUSIVE = 1
  37. GOOD = 2
  38. BAD = 3
  39. verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  40. def convert(tuples):
  41. return dict(tuples)
  42. @dataclass(frozen=True)
  43. class State:
  44. x: int
  45. y: int
  46. ski_position: int
  47. velocity: int
  48. def default_value():
  49. return {'action' : None, 'choiceValue' : None}
  50. @dataclass(frozen=True)
  51. class StateValue:
  52. ranking: float
  53. choices: dict = field(default_factory=default_value)
  54. def exec(command,verbose=True):
  55. if verbose: print(f"Executing {command}")
  56. system(f"echo {command} >> list_of_exec")
  57. return system(command)
  58. num_tests_per_cluster = 50
  59. factor_tests_per_cluster = 0.2
  60. num_ski_positions = 8
  61. num_velocities = 8
  62. def input_to_action(char):
  63. if char == "0":
  64. return Action.NOOP
  65. if char == "1":
  66. return Action.RIGHT
  67. if char == "2":
  68. return Action.LEFT
  69. if char == "3":
  70. return "reset"
  71. if char == "4":
  72. return "set_x"
  73. if char == "5":
  74. return "set_vel"
  75. if char in ["w", "a", "s", "d"]:
  76. return char
  77. def saveObservations(observations, verdict, testDir):
  78. testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
  79. if len(observations) < 20:
  80. logger.warn(f"Potentially spurious test case for {testDir}")
  81. testDir = f"{testDir}_pot_spurious"
  82. exec(f"mkdir {testDir}", verbose=False)
  83. for i, obs in enumerate(observations):
  84. img = Image.fromarray(obs)
  85. img.save(f"{testDir}/{i:003}.png")
  86. ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
  87. def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
  88. #def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
  89. #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
  90. testDir = f"{x}_{y}_{ski_position}_{velocity}"
  91. #testDir = f"{x}_{y}_{ski_position}"
  92. for i, r in enumerate(ramDICT[y]):
  93. ale.setRAM(i,r)
  94. ski_position_setting = ski_position_counter[ski_position]
  95. for i in range(0,ski_position_setting[1]):
  96. ale.act(ski_position_setting[0])
  97. ale.setRAM(14,0)
  98. ale.setRAM(25,x)
  99. ale.setRAM(14,180) # TODO
  100. all_obs = list()
  101. speed_list = list()
  102. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  103. for i in range(0,4):
  104. all_obs.append(resized_obs)
  105. for i in range(0,duration-4):
  106. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  107. all_obs.append(resized_obs)
  108. if i % 4 == 0:
  109. stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
  110. action = nn_wrapper.query(stack_tensor)
  111. ale.act(input_to_action(str(action)))
  112. else:
  113. ale.act(input_to_action(str(action)))
  114. speed_list.append(ale.getRAM()[14])
  115. if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
  116. saveObservations(all_obs, Verdict.BAD, testDir)
  117. return Verdict.BAD
  118. saveObservations(all_obs, Verdict.GOOD, testDir)
  119. return Verdict.GOOD
  120. def computeStateRanking(mdp_file, iteration):
  121. logger.info("Computing state ranking")
  122. tic()
  123. try:
  124. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop 'Rmax=? [C <= 1000]'"
  125. result = subprocess.run(command, shell=True, check=True)
  126. print(result)
  127. except Exception as e:
  128. print(e)
  129. sys.exit(-1)
  130. exec(f"mv action_ranking action_ranking_{iteration:03}")
  131. logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
  132. def fillStateRanking(file_name, match=""):
  133. logger.info(f"Parsing state ranking, {file_name}")
  134. tic()
  135. state_ranking = dict()
  136. try:
  137. with open(file_name, "r") as f:
  138. file_content = f.readlines()
  139. for line in file_content:
  140. if not "move=0" in line: continue
  141. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  142. if ranking_value <= 0.1:
  143. continue
  144. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  145. #print("stateMapping", stateMapping)
  146. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  147. choices = {key:float(value) for (key,value) in choices.items()}
  148. #print("choices", choices)
  149. #print("ranking_value", ranking_value)
  150. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
  151. #state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
  152. value = StateValue(ranking_value, choices)
  153. state_ranking[state] = value
  154. logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
  155. return state_ranking
  156. except EnvironmentError:
  157. print("Ranking file not available. Exiting.")
  158. toc()
  159. sys.exit(-1)
  160. except:
  161. toc()
  162. def createDisjunction(formulas):
  163. return " | ".join(formulas)
  164. def clusterFormula(cluster):
  165. if len(cluster) == 0: return
  166. formulas = list()
  167. for state in cluster:
  168. formulas.append(f"(x={state[0].x} & y={state[0].y} & velocity={state[0].velocity} & ski_position={state[0].ski_position})")
  169. while len(formulas) > 1:
  170. formulas_tmp = [f"({formulas[i]} | {formulas[i+1]})" for i in range(0,len(formulas)//2)]
  171. if len(formulas) % 2 == 1:
  172. formulas_tmp.append(formulas[-1])
  173. formulas = formulas_tmp
  174. return "(" + formulas[0] + ")"
  175. def clusterFormulaXY(cluster):
  176. if len(cluster) == 0: return
  177. formulas = set()
  178. for state in cluster:
  179. formulas.add(f"(x={state[0].x} & y={state[0].y})")
  180. formulas = list(formulas)
  181. while len(formulas) > 1:
  182. formulas_tmp = [f"({formulas[i]} | {formulas[i+1]})" for i in range(0,len(formulas)//2)]
  183. if len(formulas) % 2 == 1:
  184. formulas_tmp.append(formulas[-1])
  185. formulas = formulas_tmp
  186. return "(" + formulas[0] + ")"
  187. def clusterFormulaTrimmed(cluster):
  188. formula = ""
  189. states = [(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]
  190. #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
  191. skiPositionGroup = defaultdict(list)
  192. for item in states:
  193. skiPositionGroup[item[2]].append(item)
  194. #todo add velocity here
  195. firstVelocity = True
  196. for skiPosition, skiPos_group in skiPositionGroup.items():
  197. formula += f"ski_position={skiPosition} & "
  198. velocityGroup = defaultdict(list)
  199. for item in skiPos_group:
  200. velocityGroup[item[3]].append(item)
  201. for velocity, velocity_group in velocityGroup.items():
  202. if firstVelocity:
  203. firstVelocity = False
  204. else:
  205. formula += " | "
  206. formula += f" (velocity={velocity} & "
  207. firstY = True
  208. yPosGroup = defaultdict(list)
  209. for item in velocity_group:
  210. yPosGroup[item[1]].append(item)
  211. for y, y_group in yPosGroup.items():
  212. if firstY:
  213. firstY = False
  214. else:
  215. formula += " | "
  216. sorted_y_group = sorted(y_group, key=lambda s: s[0])
  217. formula += f"( y={y} & ("
  218. current_x_min = sorted_y_group[0][0]
  219. current_x = sorted_y_group[0][0]
  220. x_ranges = list()
  221. for state in sorted_y_group[1:-1]:
  222. if state[0] - current_x == 1:
  223. current_x = state[0]
  224. else:
  225. x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
  226. current_x_min = state[0]
  227. current_x = state[0]
  228. x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})")
  229. formula += " | ".join(x_ranges)
  230. formula += ") )"
  231. formula += ")"
  232. return formula
  233. def createBalancedDisjunction(indices, name):
  234. #logger.info(f"Creating balanced disjunction for {len(indices)} ({indices}) formulas")
  235. if len(indices) == 0:
  236. return f"formula {name} = false;\n"
  237. else:
  238. while len(indices) > 1:
  239. indices_tmp = [f"({indices[i]} | {indices[i+1]})" for i in range(0,len(indices)//2)]
  240. if len(indices) % 2 == 1:
  241. indices_tmp.append(indices[-1])
  242. indices = indices_tmp
  243. disjunction = f"formula {name} = " + " ".join(indices) + ";\n"
  244. return disjunction
  245. def createUnsafeFormula(clusters):
  246. label = "label \"Unsafe\" = Unsafe;\n"
  247. formulas = ""
  248. indices = list()
  249. for i, cluster in enumerate(clusters):
  250. formulas += f"formula Unsafe_{i} = {clusterFormulaXY(cluster)};\n"
  251. indices.append(f"Unsafe_{i}")
  252. return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe") + label
  253. def createSafeFormula(clusters):
  254. label = "label \"Safe\" = Safe;\n"
  255. formulas = ""
  256. indices = list()
  257. for i, cluster in enumerate(clusters):
  258. formulas += f"formula Safe_{i} = {clusterFormulaXY(cluster)};\n"
  259. indices.append(f"Safe_{i}")
  260. return formulas + "\n" + createBalancedDisjunction(indices, "Safe") + label
  261. def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
  262. logger.info("Creating next prism file")
  263. tic()
  264. initFile = f"{newFile}_no_formulas.prism"
  265. newFile = f"{newFile}_{iteration:03}.prism"
  266. exec(f"cp {initFile} {newFile}", verbose=False)
  267. with open(newFile, "a") as prism:
  268. prism.write(createSafeFormula(safeStates))
  269. prism.write(createUnsafeFormula(unsafeStates))
  270. logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
  271. ale = ALEInterface()
  272. #if SDL_SUPPORT:
  273. # ale.setBool("sound", True)
  274. # ale.setBool("display_screen", True)
  275. # Load the ROM file
  276. ale.loadROM(rom_file)
  277. with open('all_positions_v2.pickle', 'rb') as handle:
  278. ramDICT = pickle.load(handle)
  279. y_ram_setting = 60
  280. x = 70
  281. nn_wrapper = SampleFactoryNNQueryWrapper()
  282. experiment_id = int(time.time())
  283. init_mdp = "safety"
  284. exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
  285. markerSize = 1
  286. imagesDir = f"images/testing_{experiment_id}"
  287. def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
  288. #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
  289. markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities + 1) for ski_position in range(1,num_ski_positions + 1)}
  290. for state in states:
  291. s = state[0]
  292. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  293. marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
  294. markerList[(s.ski_position, s.velocity)].append(marker)
  295. for (pos, vel), marker in markerList.items():
  296. command = f"convert {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png"
  297. exec(command, verbose=False)
  298. def concatImages(prefix, iteration):
  299. images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities+1) for pos in range(1,num_ski_positions+1) ]
  300. for vel in range(0, num_velocities + 1):
  301. for pos in range(1, num_ski_positions + 1):
  302. command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
  303. command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
  304. command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
  305. exec(command, verbose=False)
  306. exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
  307. #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
  308. def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
  309. """
  310. Useful to draw a set of states, e.g. a single cluster
  311. markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
  312. logger.info(f"Drawing {len(states)} states onto {target}")
  313. tic()
  314. for state in states:
  315. s = state[0]
  316. marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  317. markerList[s.ski_position].append(marker)
  318. for pos, marker in markerList.items():
  319. command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
  320. exec(command, verbose=False)
  321. exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
  322. logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
  323. """
  324. def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
  325. logger.info(f"Drawing clusters")
  326. tic()
  327. for velocity in range(0, num_velocities + 1):
  328. for ski_position in range(1, num_ski_positions + 1):
  329. source = "images/1_full_scaled_down.png"
  330. exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  331. for _, clusterStates in clusterDict.items():
  332. color = f"{np.random.choice(range(256))}, {np.random.choice(range(256))}, {np.random.choice(range(256))}"
  333. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
  334. concatImages(target, iteration)
  335. logger.info(f"Drawing clusters - DONE: took {toc()} seconds")
  336. def drawResult(clusterDict, target, iteration):
  337. logger.info(f"Drawing clusters")
  338. for velocity in range(0,num_velocities+1):
  339. for ski_position in range(1, num_ski_positions + 1):
  340. source = "images/1_full_scaled_down.png"
  341. exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  342. for _, (clusterStates, result) in clusterDict.items():
  343. color = "100,100,100"
  344. if result == Verdict.GOOD:
  345. color = "0,200,0"
  346. elif result == Verdict.BAD:
  347. color = "200,0,0"
  348. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
  349. concatImages(target, iteration)
  350. logger.info(f"Drawing clusters - DONE: took {toc()} seconds")
  351. def _init_logger():
  352. logger = logging.getLogger('main')
  353. logger.setLevel(logging.INFO)
  354. handler = logging.StreamHandler(sys.stdout)
  355. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  356. handler.setFormatter(formatter)
  357. logger.addHandler(handler)
  358. def clusterImportantStates(ranking, iteration):
  359. logger.info(f"Starting to cluster {len(ranking)} states into clusters")
  360. tic()
  361. states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
  362. #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
  363. #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
  364. dbscan = DBSCAN(eps=15).fit(states)
  365. labels = dbscan.labels_
  366. n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
  367. logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
  368. clusterDict = {i : list() for i in range(0,n_clusters)}
  369. for i, state in enumerate(ranking):
  370. if labels[i] == -1: continue
  371. clusterDict[labels[i]].append(state)
  372. drawClusters(clusterDict, f"clusters", iteration)
  373. return clusterDict
  374. if __name__ == '__main__':
  375. _init_logger()
  376. logger = logging.getLogger('main')
  377. logger.info("Starting")
  378. n_clusters = 40
  379. testAll = False
  380. safeStates = list()
  381. unsafeStates = list()
  382. iteration = 0
  383. while True:
  384. updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
  385. computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
  386. ranking = fillStateRanking(f"action_ranking_{iteration:03}")
  387. sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
  388. clusters = clusterImportantStates(sorted_ranking, iteration)
  389. if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
  390. clusterResult = dict()
  391. for id, cluster in clusters.items():
  392. num_tests = int(factor_tests_per_cluster * len(cluster))
  393. num_tests = 1
  394. #logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
  395. randomStates = np.random.choice(len(cluster), num_tests, replace=False)
  396. randomStates = [cluster[i] for i in randomStates]
  397. verdictGood = True
  398. for state in randomStates:
  399. x = state[0].x
  400. y = state[0].y
  401. ski_pos = state[0].ski_position
  402. velocity = state[0].velocity
  403. #result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
  404. result = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
  405. result = Verdict.BAD # TODO REMOVE ME!!!!!!!!!!!!!!
  406. if result == Verdict.BAD:
  407. if testAll:
  408. failingPerCluster[id].append(state)
  409. else:
  410. clusterResult[id] = (cluster, Verdict.BAD)
  411. verdictGood = False
  412. unsafeStates.append(cluster)
  413. break
  414. if verdictGood:
  415. clusterResult[id] = (cluster, Verdict.GOOD)
  416. safeStates.append(cluster)
  417. logger.info(f"Iteration: {iteration:03} -\tSafe Results : {sum([len(c) for c in safeStates])} -\tUnsafe Results:{sum([len(c) for c in unsafeStates])}")
  418. if testAll: drawClusters(failingPerCluster, f"failing", iteration)
  419. #drawResult(clusterResult, "result", iteration)
  420. iteration += 1