You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

572 lines
25 KiB

6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
  1. import sys
  2. import operator
  3. from os import listdir, system
  4. import subprocess
  5. import re
  6. from collections import defaultdict
  7. from random import randrange
  8. from ale_py import ALEInterface, SDL_SUPPORT, Action
  9. from PIL import Image
  10. from matplotlib import pyplot as plt
  11. import cv2
  12. import pickle
  13. import queue
  14. from dataclasses import dataclass, field
  15. from sklearn.cluster import KMeans, DBSCAN
  16. from enum import Enum
  17. from copy import deepcopy
  18. import numpy as np
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. #import readchar
  22. from sample_factory.algo.utils.tensor_dict import TensorDict
  23. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  24. import time
# Absolute paths to external artifacts: the Tempest/Storm model-checker binary
# and the Atari "Skiing" ROM loaded into ALE below.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  27. def tic():
  28. import time
  29. global startTime_for_tictoc
  30. startTime_for_tictoc = time.time()
  31. def toc():
  32. import time
  33. if 'startTime_for_tictoc' in globals():
  34. return time.time() - startTime_for_tictoc
class Verdict(Enum):
    """Outcome of replaying the NN policy from one concrete state."""
    INCONCLUSIVE = 1  # test setup failed (RAM restore raised) -- see run_single_test
    GOOD = 2          # run finished without the skier getting stuck
    BAD = 3           # skier got stuck (RAM speed stayed 0 over recent frames)
# "R,G,B" component strings per verdict (red=BAD, blue=INCONCLUSIVE, green=GOOD);
# format matches ImageMagick-style fill colours. Appears unused in this file.
verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  40. def convert(tuples):
  41. return dict(tuples)
@dataclass(frozen=True)
class State:
    """Immutable, hashable abstract skier state used as ranking-dict key."""
    x: int             # horizontal position (injected into ALE RAM[25])
    y: int             # vertical position (selects the RAM snapshot ramDICT[y])
    ski_position: int  # discretised ski angle, 1..num_ski_positions
    velocity: int      # discretised velocity bucket (raw value // 2 when parsed)
  48. def default_value():
  49. return {'action' : None, 'choiceValue' : None}
@dataclass(frozen=True)
class StateValue:
    """Ranking of a state together with its per-action choice values."""
    ranking: float  # ranking value parsed from the model checker's output
    choices: dict = field(default_factory=default_value)  # action name -> value
@dataclass(frozen=False)
class TestResult:
    """Per-iteration model-checking statistics plus testing counters.

    The six init_check_* fields hold min/max/avg of the pessimistic (Pmin)
    and optimistic (Pmax) initial-state checks; the three counters are
    overwritten by the main loop after testing.
    """
    init_check_pes_min: float
    init_check_pes_max: float
    init_check_pes_avg: float
    init_check_opt_min: float
    init_check_opt_max: float
    init_check_opt_avg: float
    safe_states: int      # number of states deemed safe so far
    unsafe_states: int    # number of states deemed unsafe so far
    policy_queries: int   # cumulative number of NN policy queries

    def __str__(self):
        """Human-readable summary of the six init-check values."""
        return f"""Test Result:
init_check_pes_min: {self.init_check_pes_min}
init_check_pes_max: {self.init_check_pes_max}
init_check_pes_avg: {self.init_check_pes_avg}
init_check_opt_min: {self.init_check_opt_min}
init_check_opt_max: {self.init_check_opt_max}
init_check_opt_avg: {self.init_check_opt_avg}
"""

    def csv(self, ws=" "):
        """All nine fields as a single ws-separated row (floats to 4 decimals)."""
        return f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}{self.safe_states}{ws}{self.unsafe_states}{ws}{self.policy_queries}"
# NOTE(review): shadows the builtin exec(); kept because callers throughout
# this file rely on the name.
def exec(command, verbose=True):
    """Run a shell command via os.system, appending it to the 'list_of_exec' log.

    Returns the os.system() exit status.
    WARNING: the command string reaches a shell unescaped -- only use with
    trusted, locally-constructed commands.
    """
    if verbose: print(f"Executing {command}")
    system(f"echo {command} >> list_of_exec")
    return system(command)
# Testing / discretisation parameters.
num_tests_per_cluster = 50      # absolute cap; appears unused here in favour of the factor below
factor_tests_per_cluster = 0.2  # fraction of each cluster's states that get tested
num_ski_positions = 8           # number of discrete ski angle positions
num_velocities = 5              # number of discrete velocity buckets
  84. def input_to_action(char):
  85. if char == "0":
  86. return Action.NOOP
  87. if char == "1":
  88. return Action.RIGHT
  89. if char == "2":
  90. return Action.LEFT
  91. if char == "3":
  92. return "reset"
  93. if char == "4":
  94. return "set_x"
  95. if char == "5":
  96. return "set_vel"
  97. if char in ["w", "a", "s", "d"]:
  98. return char
  99. def saveObservations(observations, verdict, testDir):
  100. testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
  101. if len(observations) < 20:
  102. logger.warn(f"Potentially spurious test case for {testDir}")
  103. testDir = f"{testDir}_pot_spurious"
  104. exec(f"mkdir {testDir}", verbose=False)
  105. for i, obs in enumerate(observations):
  106. img = Image.fromarray(obs)
  107. img.save(f"{testDir}/{i:003}.png")
# For each target ski position (1..8): the action to press and for how many
# frames, in order to steer the skier into that position after a RAM restore.
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
def run_single_test(ale, nn_wrapper, x, y, ski_position, velocity, duration=50):
    """Replay the NN policy from the concrete state (x, y, ski_position, velocity).

    Restores the RAM snapshot for row y, steers into the requested ski
    position, injects x into RAM, then lets the policy act for `duration`
    frames. Returns (Verdict, number_of_policy_queries).
    """
    #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
    testDir = f"{x}_{y}_{ski_position}_{velocity}"
    try:
        # Restore the full RAM snapshot recorded for vertical position y.
        for i, r in enumerate(ramDICT[y]):
            ale.setRAM(i, r)
        # Hold LEFT/RIGHT for a position-specific number of frames to reach
        # the desired ski angle.
        ski_position_setting = ski_position_counter[ski_position]
        for i in range(0, ski_position_setting[1]):
            ale.act(ski_position_setting[0])
        ale.setRAM(14, 0)
        ale.setRAM(25, x)   # horizontal position
        # NOTE(review): RAM[14] looks like the speed cell (read again below);
        # the `velocity` argument itself is never written -- TODO in original.
        ale.setRAM(14, 180)  # TODO
    except Exception as e:
        print(e)
        logger.warn(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
        return (Verdict.INCONCLUSIVE, 0)
    num_queries = 0
    all_obs = list()
    speed_list = list()
    # Seed the 4-frame observation stack with the current screen.
    resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
    for i in range(0, 4):
        all_obs.append(resized_obs)
    for i in range(0, duration - 4):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)
        if i % 4 == 0:
            # Query the policy every 4th frame on the last 4 observations ...
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            num_queries += 1
            ale.act(input_to_action(str(action)))
        else:
            # ... and repeat the previous action in between.
            ale.act(input_to_action(str(action)))
        speed_list.append(ale.getRAM()[14])
        # Stuck detection: zero speed over the recent frames => BAD.
        if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
            #saveObservations(all_obs, Verdict.BAD, testDir)
            return (Verdict.BAD, num_queries)
    #saveObservations(all_obs, Verdict.GOOD, testDir)
    return (Verdict.GOOD, num_queries)
  147. def skiPositionFormulaList(name):
  148. formulas = list()
  149. for i in range(1, num_ski_positions+1):
  150. formulas.append(f"\"{name}_{i}\"")
  151. return createBalancedDisjunction(formulas)
  152. def computeStateRanking(mdp_file, iteration):
  153. logger.info("Computing state ranking")
  154. tic()
  155. prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  156. prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  157. prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  158. prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  159. prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  160. prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  161. prop += 'Rmax=? [C <= 200]'
  162. results = list()
  163. try:
  164. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
  165. output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
  166. num_states = 0
  167. for line in output:
  168. #print(line)
  169. if "States:" in line:
  170. num_states = int(line.split(" ")[-1])
  171. if "Result" in line and not len(results) >= 6:
  172. range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
  173. if range_value:
  174. results.append(float(range_value.group(2)))
  175. results.append(float(range_value.group(3)))
  176. else:
  177. value = re.search(r"(.*:)(.*)", line)
  178. results.append(float(value.group(2)))
  179. exec(f"mv action_ranking action_ranking_{iteration:03}")
  180. except subprocess.CalledProcessError as e:
  181. # todo die gracefully if ranking is uniform
  182. print(e.output)
  183. logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
  184. return TestResult(*tuple(results),0,0,0), num_states
  185. def fillStateRanking(file_name, match=""):
  186. logger.info(f"Parsing state ranking, {file_name}")
  187. tic()
  188. state_ranking = dict()
  189. try:
  190. with open(file_name, "r") as f:
  191. file_content = f.readlines()
  192. for line in file_content:
  193. if not "move=0" in line: continue
  194. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  195. if ranking_value <= 0.1:
  196. continue
  197. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  198. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  199. choices = {key:float(value) for (key,value) in choices.items()}
  200. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
  201. value = StateValue(ranking_value, choices)
  202. state_ranking[state] = value
  203. logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
  204. return state_ranking
  205. except EnvironmentError:
  206. print("Ranking file not available. Exiting.")
  207. toc()
  208. sys.exit(-1)
  209. except:
  210. toc()
  211. def createDisjunction(formulas):
  212. return " | ".join(formulas)
  213. def statesFormulaTrimmed(states, name):
  214. #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
  215. skiPositionGroup = defaultdict(list)
  216. for item in states:
  217. skiPositionGroup[item[2]].append(item)
  218. formulas = list()
  219. for skiPosition, skiPos_group in skiPositionGroup.items():
  220. formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & "
  221. #print(f"{name} ski_pos:{skiPosition}")
  222. velocityGroup = defaultdict(list)
  223. velocityFormulas = list()
  224. for item in skiPos_group:
  225. velocityGroup[item[3]].append(item)
  226. for velocity, velocity_group in velocityGroup.items():
  227. #print(f"\tvel:{velocity}")
  228. formulasPerSkiPosition = list()
  229. yPosGroup = defaultdict(list)
  230. yFormulas = list()
  231. for item in velocity_group:
  232. yPosGroup[item[1]].append(item)
  233. for y, y_group in yPosGroup.items():
  234. #print(f"\t\ty:{y}")
  235. sorted_y_group = sorted(y_group, key=lambda s: s[0])
  236. current_x_min = sorted_y_group[0][0]
  237. current_x = sorted_y_group[0][0]
  238. x_ranges = list()
  239. for state in sorted_y_group[1:-1]:
  240. if state[0] - current_x == 1:
  241. current_x = state[0]
  242. else:
  243. x_ranges.append(f" ({current_x_min}<=x&x<={current_x})")
  244. current_x_min = state[0]
  245. current_x = state[0]
  246. x_ranges.append(f" {current_x_min}<=x&x<={sorted_y_group[-1][0]}")
  247. yFormulas.append(f" (y={y} & {createBalancedDisjunction(x_ranges)})")
  248. #x_ranges.clear()
  249. #velocityFormulas.append(f"(velocity={velocity} & {createBalancedDisjunction(yFormulas)})")
  250. velocityFormulas.append(f"({createBalancedDisjunction(yFormulas)})")
  251. #yFormulas.clear()
  252. formula += createBalancedDisjunction(velocityFormulas) + ");"
  253. #velocityFormulas.clear()
  254. formulas.append(formula)
  255. for i in range(1, num_ski_positions+1):
  256. if i in skiPositionGroup:
  257. continue
  258. formulas.append(f"formula {name}_{i} = false;")
  259. return "\n".join(formulas) + "\n"
  260. # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
  261. def pairwise(iterable):
  262. "s -> (s0, s1), (s2, s3), (s4, s5), ..."
  263. a = iter(iterable)
  264. return zip(a, a)
  265. def createBalancedDisjunction(formulas):
  266. if len(formulas) == 0:
  267. return "false"
  268. while len(formulas) > 1:
  269. formulas_tmp = [f"({f} | {g})" for f,g in pairwise(formulas)]
  270. if len(formulas) % 2 == 1:
  271. formulas_tmp.append(formulas[-1])
  272. formulas = formulas_tmp
  273. return " ".join(formulas)
  274. def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
  275. logger.info("Creating next prism file")
  276. tic()
  277. initFile = f"{newFile}_no_formulas.prism"
  278. newFile = f"{newFile}_{iteration:03}.prism"
  279. exec(f"cp {initFile} {newFile}", verbose=False)
  280. with open(newFile, "a") as prism:
  281. prism.write(statesFormulaTrimmed(safeStates, "Safe"))
  282. prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe"))
  283. for i in range(1,num_ski_positions+1):
  284. prism.write(f"label \"Safe_{i}\" = Safe_{i};\n")
  285. prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n")
  286. logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
# ---- Global setup (runs at import time) ----
ale = ALEInterface()
#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# Pre-recorded RAM snapshots: maps a vertical position y to the full RAM
# contents to restore (see run_single_test).
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70
# Wrapper around the trained sample-factory policy network.
nn_wrapper = SampleFactoryNNQueryWrapper()
# Unique id for this run; all images/results live below images/testing_<id>.
experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
markerSize = 1
imagesDir = f"images/testing_{experiment_id}"
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0, markerSize=1, drawCircle=False):
    """Draw one marker per state onto the per-(ski_position, velocity) images.

    states: iterable of (State, StateValue) pairs; color: BGR tuple for cv2.
    Reads the current images from disk, draws rectangles (or circles), and
    writes both the per-velocity and the per-position merged images back,
    so markers accumulate across calls.
    """
    #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
    markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities) for ski_position in range(1,num_ski_positions + 1)}
    images = dict()
    mergedImages = dict()
    # Load the current state of every image from disk.
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0, num_velocities):
            images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
        mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
    # Bucket one marker per state into its (ski_position, velocity) slot.
    # NOTE(review): the alpha value (marker[1]) is stored but never used by
    # the cv2 drawing calls below.
    for state in states:
        s = state[0]
        marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)]
        markerList[(s.ski_position, s.velocity)].append(marker)
    for (pos, vel), marker in markerList.items():
        if len(marker) == 0: continue
        if drawCircle:
            for m in marker:
                images[(pos,vel)] = cv2.circle(images[(pos,vel)], m[2], 1, m[0], thickness=-1)
                mergedImages[pos] = cv2.circle(mergedImages[pos], m[2], 1, m[0], thickness=-1)
        else:
            for m in marker:
                images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED)
                mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
    # Persist the updated images.
    for (ski_position, velocity), image in images.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
    for ski_position, image in mergedImages.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)
  330. def concatImages(prefix, iteration):
  331. logger.info(f"Concatenating images")
  332. images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities) for pos in range(1,num_ski_positions+1)]
  333. mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1,num_ski_positions+1)]
  334. for vel in range(0, num_velocities):
  335. for pos in range(1, num_ski_positions + 1):
  336. command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
  337. command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
  338. command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
  339. exec(command, verbose=False)
  340. exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
  341. exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}_merged.png", verbose=False)
  342. #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
  343. logger.info(f"Concatenating images - DONE")
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    # NOTE(review): the entire body is one string literal, i.e. this function
    # is disabled dead code that does nothing and returns None. Kept as-is
    # for reference (old ImageMagick-based drawing path).
    """
    Useful to draw a set of states, e.g. a single cluster
    markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    """
  360. def drawClusters(clusterDict, target, iteration, alpha_factor=1.0): # TODO do not draw already drawn clusters
  361. logger.info(f"Drawing {len(clusterDict)} clusters")
  362. tic()
  363. for _, clusterStates in clusterDict.items():
  364. color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
  365. color = (int(color[0]), int(color[1]), int(color[2]))
  366. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
  367. concatImages(target, iteration)
  368. logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
  369. def drawResult(clusterDict, target, iteration): # TODO do not draw already drawn clusters
  370. logger.info(f"Drawing {len(clusterDict)} results")
  371. tic()
  372. for id, (clusterStates, result) in clusterDict.items():
  373. # opencv wants BGR
  374. color = (100,100,100)
  375. if result == Verdict.GOOD:
  376. color = (0,200,0)
  377. elif result == Verdict.BAD:
  378. color = (0,0,200)
  379. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
  380. logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
  381. def _init_logger():
  382. logger = logging.getLogger('main')
  383. logger.setLevel(logging.INFO)
  384. handler = logging.StreamHandler(sys.stdout)
  385. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  386. handler.setFormatter(formatter)
  387. logger.addHandler(handler)
def clusterImportantStates(ranking, iteration):
    """Cluster the ranked states with k-means (k = number_of_states // 15).

    ranking: list of (State, StateValue) pairs. ski_position and velocity are
    scaled by 20 so they weigh comparably to the pixel coordinates; the
    ranking value is included as a fifth feature. Draws the clusters and
    returns {cluster_label: [(State, StateValue), ...]}.
    """
    logger.info(f"Starting to cluster {len(ranking)} states into clusters")
    tic()
    states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
    kmeans = KMeans(len(states) // 15, random_state=0, n_init="auto").fit(states)
    #dbscan = DBSCAN(eps=5).fit(states)
    labels = kmeans.labels_
    # The label -1 handling below is a leftover from the DBSCAN experiment
    # above; k-means never produces -1 ("noise") labels.
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
    clusterDict = {i : list() for i in range(0,n_clusters)}
    strayStates = list()
    for i, state in enumerate(ranking):
        if labels[i] == -1:
            # Each noise point gets its own singleton cluster.
            clusterDict[n_clusters + len(strayStates) + 1] = list()
            clusterDict[n_clusters + len(strayStates) + 1].append(state)
            strayStates.append(state)
            continue
        clusterDict[labels[i]].append(state)
    if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1")
    drawClusters(clusterDict, f"clusters", iteration)
    return clusterDict
  409. if __name__ == '__main__':
  410. _init_logger()
  411. logger = logging.getLogger('main')
  412. logger.info("Starting")
  413. testAll = False
  414. num_queries = 0
  415. source = "images/1_full_scaled_down.png"
  416. for ski_position in range(1, num_ski_positions + 1):
  417. for velocity in range(0,num_velocities):
  418. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  419. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  420. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
  421. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)
  422. safeStates = set()
  423. unsafeStates = set()
  424. iteration = 0
  425. results = list()
  426. goodVerdicts = 0
  427. badVerdicts = 0
  428. goodVerdictTestCases = list()
  429. badVerdictTestCases = list()
  430. safeClusters = 0
  431. unsafeClusters = 0
  432. eps = 0.1
  433. while True:
  434. updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
  435. modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
  436. if len(results) > 0:
  437. modelCheckingResult.safeStates = results[-1].safeStates
  438. modelCheckingResult.unsafeStates = results[-1].unsafeStates
  439. modelCheckingResult.policy_queries = results[-1].policy_queries
  440. results.append(modelCheckingResult)
  441. logger.info(f"Model Checking Result: {modelCheckingResult}")
  442. if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
  443. logger.info(f"Absolute difference between average estimates is below eps = {eps}... finishing!")
  444. break
  445. ranking = fillStateRanking(f"action_ranking_{iteration:03}")
  446. sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
  447. try:
  448. clusters = clusterImportantStates(sorted_ranking, iteration)
  449. except Exception as e:
  450. print(e)
  451. break
  452. if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
  453. clusterResult = dict()
  454. logger.info(f"Running tests")
  455. tic()
  456. for id, cluster in clusters.items():
  457. num_tests = int(factor_tests_per_cluster * len(cluster))
  458. #logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
  459. randomStates = np.random.choice(len(cluster), num_tests, replace=False)
  460. randomStates = [cluster[i] for i in randomStates]
  461. verdictGood = True
  462. for state in randomStates:
  463. x = state[0].x
  464. y = state[0].y
  465. ski_pos = state[0].ski_position
  466. velocity = state[0].velocity
  467. result, num_queries_this_test_case = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
  468. num_queries += num_queries_this_test_case
  469. if result == Verdict.BAD:
  470. if testAll:
  471. failingPerCluster[id].append(state)
  472. else:
  473. clusterResult[id] = (cluster, Verdict.BAD)
  474. verdictGood = False
  475. unsafeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  476. badVerdictTestCases.append(state)
  477. elif result == Verdict.GOOD:
  478. goodVerdicts += 1
  479. goodVerdictTestCases.append(state)
  480. if verdictGood:
  481. clusterResult[id] = (cluster, Verdict.GOOD)
  482. safeClusters += 1
  483. safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  484. else:
  485. unsafeClusters += 1
  486. logger.info(f"Tested Cluster: {iteration:03}\tSafe Cluster States : {len(safeStates)}({safeClusters}/{len(clusters)})\tUnsafe Cluster States:{len(unsafeStates)}({unsafeClusters}/{len(clusters)})\tGood Test Cases:{goodVerdicts}\tFailing Test Cases:{badVerdicts}\t{len(safeStates)/len(unsafeStates)} - {goodVerdicts/badVerdicts}")
  487. results[-1].safeStates = len(safeStates)
  488. results[-1].unsafeStates = len(unsafeStates)
  489. results[-1].policy_queries = num_queries
  490. results[-1].safe_cluster = safeClusters
  491. results[-1].unsafe_cluster = unsafeClusters
  492. results[-1].good_verdicts = goodVerdicts
  493. results[-1].bad_verdicts = badVerdicts
  494. # Account for self-loop states after first iteration
  495. if iteration > 0:
  496. results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafeStates + 0.0*results[-2].safeStates)
  497. results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafeStates + 1.0*results[-2].safeStates)
  498. for result in results:
  499. print(result.csv())
  500. if testAll: drawClusters(failingPerCluster, f"failing", iteration)
  501. drawResult(clusterResult, "result", iteration)
  502. drawOntoSkiPosImage(goodVerdictTestCases, (10,255,50), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
  503. drawOntoSkiPosImage(badVerdictTestCases, (0,0,0), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
  504. concatImages(target, iteration)
  505. iteration += 1
  506. for result in results:
  507. print(result.csv())