You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

605 lines
26 KiB

7 months ago
  1. import sys
  2. import operator
  3. from copy import deepcopy
  4. from os import listdir, system
  5. import subprocess
  6. import re
  7. from collections import defaultdict
  8. from random import randrange
  9. from ale_py import ALEInterface, SDL_SUPPORT, Action
  10. from PIL import Image
  11. from matplotlib import pyplot as plt
  12. import cv2
  13. import pickle
  14. import queue
  15. from dataclasses import dataclass, field
  16. from sklearn.cluster import KMeans, DBSCAN
  17. from enum import Enum
  18. from copy import deepcopy
  19. import numpy as np
  20. import logging
  21. logger = logging.getLogger(__name__)
  22. #import readchar
  23. from sample_factory.algo.utils.tensor_dict import TensorDict
  24. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  25. import time
# Absolute path to the tempest/storm model-checker binary used for state ranking.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
# Absolute path to the Atari "Skiing" ROM loaded into the ALE emulator below.
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  28. def tic():
  29. import time
  30. global startTime_for_tictoc
  31. startTime_for_tictoc = time.time()
  32. def toc():
  33. import time
  34. if 'startTime_for_tictoc' in globals():
  35. return time.time() - startTime_for_tictoc
class Verdict(Enum):
    """Outcome of a single test episode."""
    INCONCLUSIVE = 1  # the test could not be executed/decided
    GOOD = 2          # episode completed without the failure condition
    BAD = 3           # failure condition observed (speed stuck at zero)
# "r,g,b" colour strings for rendering verdicts (used by external drawing tools).
verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  41. def convert(tuples):
  42. return dict(tuples)
@dataclass(frozen=True)
class State:
    """Immutable (hashable) discretized game state: position, ski orientation, speed."""
    x: int
    y: int
    ski_position: int
    velocity: int
  49. def default_value():
  50. return {'action' : None, 'choiceValue' : None}
@dataclass(frozen=True)
class StateValue:
    """Ranking score of a state plus its per-action choice values."""
    ranking: float
    # Maps choice names to values; defaults to {'action': None, 'choiceValue': None}.
    choices: dict = field(default_factory=default_value)
@dataclass(frozen=False)
class TestResult:
    """Aggregated metrics of one model-checking + testing iteration."""
    # Min/max/avg of the pessimistic (Pmin) and optimistic (Pmax)
    # model-checking results, parsed from the checker's output.
    init_check_pes_min: float
    init_check_pes_max: float
    init_check_pes_avg: float
    init_check_opt_min: float
    init_check_opt_max: float
    init_check_opt_avg: float
    # Testing counters, filled in incrementally while clusters are tested.
    safe_states: int
    unsafe_states: int
    safe_cluster: int
    unsafe_cluster: int
    good_verdicts: int
    bad_verdicts: int
    policy_queries: int
    def __str__(self):
        # Human-readable summary of the model-checking values only.
        return f"""Test Result:
init_check_pes_min: {self.init_check_pes_min}
init_check_pes_max: {self.init_check_pes_max}
init_check_pes_avg: {self.init_check_pes_avg}
init_check_opt_min: {self.init_check_opt_min}
init_check_opt_max: {self.init_check_opt_max}
init_check_opt_avg: {self.init_check_opt_avg}
"""
    @staticmethod
    def csv_header(ws=" "):
        # Column names matching csv(); 'ws' is the separator between columns.
        string = f"pesmin{ws}pesmax{ws}pesavg{ws}"
        string += f"optmin{ws}optmax{ws}optavg{ws}"
        string += f"sState{ws}uState{ws}"
        string += f"sClust{ws}uClust{ws}"
        string += f"gVerd{ws}bVerd{ws}queries"
        return string
    def csv(self):
        # One data row. NOTE(review): the separator switches from a space
        # (model-checking columns) to a tab (counter columns) mid-row, while
        # csv_header() uses one separator throughout — confirm downstream
        # consumers expect this mix.
        ws = " "
        string = f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}"
        string += f"{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}"
        ws = "\t"
        string += f"{self.safe_states}{ws}{self.unsafe_states}{ws}"
        string += f"{self.safe_cluster}{ws}{self.unsafe_cluster}{ws}"
        string += f"{self.good_verdicts}{ws}{self.bad_verdicts}{ws}{self.policy_queries}"
        return string
def exec(command,verbose=True):
    # Run a shell command, first appending it to 'list_of_exec' so the whole
    # run can be reproduced later. Returns the os.system() exit status.
    # NOTE(review): shadows the built-in exec() and passes the command through
    # the shell unescaped — only safe for internally constructed commands.
    if verbose: print(f"Executing {command}")
    system(f"echo {command} >> list_of_exec")
    return system(command)
# Upper bound of tests per cluster (the actual count is a fraction of the
# cluster size; see run_experiment).
num_tests_per_cluster = 50
#factor_tests_per_cluster = 0.2
# Discretization of the agent's state space.
num_ski_positions = 8
num_velocities = 5
  104. def input_to_action(char):
  105. if char == "0":
  106. return Action.NOOP
  107. if char == "1":
  108. return Action.RIGHT
  109. if char == "2":
  110. return Action.LEFT
  111. if char == "3":
  112. return "reset"
  113. if char == "4":
  114. return "set_x"
  115. if char == "5":
  116. return "set_vel"
  117. if char in ["w", "a", "s", "d"]:
  118. return char
  119. def saveObservations(observations, verdict, testDir):
  120. testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
  121. if len(observations) < 20:
  122. logger.warn(f"Potentially spurious test case for {testDir}")
  123. testDir = f"{testDir}_pot_spurious"
  124. exec(f"mkdir {testDir}", verbose=False)
  125. for i, obs in enumerate(observations):
  126. img = Image.fromarray(obs)
  127. img.save(f"{testDir}/{i:003}.png")
# For each discrete ski_position (1..8): the action to press and how many
# frames to repeat it so the agent ends up in that orientation after a reset.
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
  129. def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
  130. #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
  131. testDir = f"{x}_{y}_{ski_position}_{velocity}"
  132. try:
  133. for i, r in enumerate(ramDICT[y]):
  134. ale.setRAM(i,r)
  135. ski_position_setting = ski_position_counter[ski_position]
  136. for i in range(0,ski_position_setting[1]):
  137. ale.act(ski_position_setting[0])
  138. ale.setRAM(14,0)
  139. ale.setRAM(25,x)
  140. ale.setRAM(14,180) # TODO
  141. except Exception as e:
  142. print(e)
  143. logger.warn(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
  144. return (Verdict.INCONCLUSIVE, 0)
  145. num_queries = 0
  146. all_obs = list()
  147. speed_list = list()
  148. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  149. for i in range(0,4):
  150. all_obs.append(resized_obs)
  151. for i in range(0,duration-4):
  152. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  153. all_obs.append(resized_obs)
  154. if i % 4 == 0:
  155. stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
  156. action = nn_wrapper.query(stack_tensor)
  157. num_queries += 1
  158. ale.act(input_to_action(str(action)))
  159. else:
  160. ale.act(input_to_action(str(action)))
  161. speed_list.append(ale.getRAM()[14])
  162. if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
  163. #saveObservations(all_obs, Verdict.BAD, testDir)
  164. return (Verdict.BAD, num_queries)
  165. #saveObservations(all_obs, Verdict.GOOD, testDir)
  166. return (Verdict.GOOD, num_queries)
  167. def skiPositionFormulaList(name):
  168. formulas = list()
  169. for i in range(1, num_ski_positions+1):
  170. formulas.append(f"\"{name}_{i}\"")
  171. return createBalancedDisjunction(formulas)
def computeStateRanking(mdp_file, iteration):
    """Model-check the PRISM MDP with tempest/storm and parse the output.

    Builds six filter properties (min/max/avg over Pmin and Pmax of staying
    safe) plus a cumulative-reward query, runs the external model checker,
    and extracts the numeric results and the state count from its text output.
    Also renames the checker's 'action_ranking' file per iteration.

    Returns (TestResult, num_states); the TestResult carries the six parsed
    values with all testing counters initialised to 0.
    """
    logger.info("Computing state ranking")
    tic()
    prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
    prop += 'Rmax=? [C <= 200]'
    results = list()
    try:
        command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
        output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
        num_states = 0
        for line in output:
            #print(line)
            if "States:" in line:
                num_states = int(line.split(" ")[-1])
            # Results appear either as a "[lo, hi]" range (two values) or as a
            # single value; only the first six results are collected.
            if "Result" in line and not len(results) >= 6:
                range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
                if range_value:
                    results.append(float(range_value.group(2)))
                    results.append(float(range_value.group(3)))
                else:
                    value = re.search(r"(.*:)(.*)", line)
                    results.append(float(value.group(2)))
        # Preserve the checker's per-state action ranking for this iteration.
        exec(f"mv action_ranking action_ranking_{iteration:03}")
    except subprocess.CalledProcessError as e:
        # todo die gracefully if ranking is uniform
        # NOTE(review): after a checker failure 'results' may hold fewer than
        # six values and 'num_states' may be unbound — the return statement
        # below would then raise; confirm this is acceptable.
        print(e.output)
    logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
    return TestResult(*tuple(results),0,0,0,0,0,0,0), num_states
def fillStateRanking(file_name, match=""):
    """Parse the action-ranking file written by the model checker.

    Returns a dict mapping State -> StateValue(ranking, per-action values)
    for every 'move=0' line whose ranking value exceeds 0.1.

    Exits the process if the file cannot be opened. NOTE(review): any other
    error is silently swallowed by the bare except below and the function
    implicitly returns None — callers do not check for this.
    """
    logger.info(f"Parsing state ranking, {file_name}")
    tic()
    state_ranking = dict()
    try:
        with open(file_name, "r") as f:
            file_content = f.readlines()
            for line in file_content:
                # Only the first move of each state is considered.
                if not "move=0" in line: continue
                ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
                # Skip low-ranked states entirely.
                if ranking_value <= 0.1:
                    continue
                stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
                choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
                choices = {key:float(value) for (key,value) in choices.items()}
                # Velocity from the file is halved — presumably to match the
                # coarser discretization used elsewhere; TODO confirm.
                state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
                value = StateValue(ranking_value, choices)
                state_ranking[state] = value
        logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
        return state_ranking
    except EnvironmentError:
        print("Ranking file not available. Exiting.")
        toc()
        sys.exit(-1)
    except:
        toc()
  231. def createDisjunction(formulas):
  232. return " | ".join(formulas)
  233. def statesFormulaTrimmed(states, name):
  234. #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
  235. skiPositionGroup = defaultdict(list)
  236. for item in states:
  237. skiPositionGroup[item[2]].append(item)
  238. formulas = list()
  239. for skiPosition, skiPos_group in skiPositionGroup.items():
  240. formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & "
  241. #print(f"{name} ski_pos:{skiPosition}")
  242. velocityGroup = defaultdict(list)
  243. velocityFormulas = list()
  244. for item in skiPos_group:
  245. velocityGroup[item[3]].append(item)
  246. for velocity, velocity_group in velocityGroup.items():
  247. #print(f"\tvel:{velocity}")
  248. formulasPerSkiPosition = list()
  249. yPosGroup = defaultdict(list)
  250. yFormulas = list()
  251. for item in velocity_group:
  252. yPosGroup[item[1]].append(item)
  253. for y, y_group in yPosGroup.items():
  254. #print(f"\t\ty:{y}")
  255. sorted_y_group = sorted(y_group, key=lambda s: s[0])
  256. current_x_min = sorted_y_group[0][0]
  257. current_x = sorted_y_group[0][0]
  258. x_ranges = list()
  259. for state in sorted_y_group[1:-1]:
  260. if state[0] - current_x == 1:
  261. current_x = state[0]
  262. else:
  263. x_ranges.append(f" ({current_x_min}<=x&x<={current_x})")
  264. current_x_min = state[0]
  265. current_x = state[0]
  266. x_ranges.append(f" {current_x_min}<=x&x<={sorted_y_group[-1][0]}")
  267. yFormulas.append(f" (y={y} & {createBalancedDisjunction(x_ranges)})")
  268. #x_ranges.clear()
  269. #velocityFormulas.append(f"(velocity={velocity} & {createBalancedDisjunction(yFormulas)})")
  270. velocityFormulas.append(f"({createBalancedDisjunction(yFormulas)})")
  271. #yFormulas.clear()
  272. formula += createBalancedDisjunction(velocityFormulas) + ");"
  273. #velocityFormulas.clear()
  274. formulas.append(formula)
  275. for i in range(1, num_ski_positions+1):
  276. if i in skiPositionGroup:
  277. continue
  278. formulas.append(f"formula {name}_{i} = false;")
  279. return "\n".join(formulas) + "\n"
  280. # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
  281. def pairwise(iterable):
  282. "s -> (s0, s1), (s2, s3), (s4, s5), ..."
  283. a = iter(iterable)
  284. return zip(a, a)
  285. def createBalancedDisjunction(formulas):
  286. if len(formulas) == 0:
  287. return "false"
  288. while len(formulas) > 1:
  289. formulas_tmp = [f"({f} | {g})" for f,g in pairwise(formulas)]
  290. if len(formulas) % 2 == 1:
  291. formulas_tmp.append(formulas[-1])
  292. formulas = formulas_tmp
  293. return " ".join(formulas)
def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
    """Create the PRISM model for this iteration by appending Safe/Unsafe formulas.

    Copies '<newFile>_no_formulas.prism' to '<newFile>_<iteration>.prism',
    appends the state formulas for the given safe/unsafe state sets, and adds
    matching label definitions for every ski position.
    """
    logger.info("Creating next prism file")
    tic()
    initFile = f"{newFile}_no_formulas.prism"
    newFile = f"{newFile}_{iteration:03}.prism"
    exec(f"cp {initFile} {newFile}", verbose=False)
    with open(newFile, "a") as prism:
        prism.write(statesFormulaTrimmed(safeStates, "Safe"))
        prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe"))
        for i in range(1,num_ski_positions+1):
            prism.write(f"label \"Safe_{i}\" = Safe_{i};\n")
            prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n")
    logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
# --- Script-level setup: emulator, RAM snapshots, policy wrapper, output dirs ---
ale = ALEInterface()
#if SDL_SUPPORT:
# ale.setBool("sound", True)
# ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# Pre-recorded RAM snapshots keyed by y position; used to reset the emulator
# to arbitrary states in run_single_test.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
# NOTE(review): these two appear unused by the functions below — possibly leftovers.
y_ram_setting = 60
x = 70
nn_wrapper = SampleFactoryNNQueryWrapper()
# Unique id for this run; all images are collected under images/testing_<id>.
experiment_id = int(time.time())
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
imagesDir = f"images/testing_{experiment_id}"
def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0, markerSize=1, drawCircle=False):
    """Draw one marker per state onto the per-(ski_position, velocity) pane
    images and onto the merged per-ski_position images, then write them back.

    states:     iterable of (State, StateValue) pairs.
    color:      BGR tuple used for every marker.
    drawCircle: draw small filled circles instead of filled rectangles.

    NOTE(review): the per-marker alpha (alpha_factor * ranking) is stored in
    each marker but never used by the cv2 draw calls below.
    """
    #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
    markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities) for ski_position in range(1,num_ski_positions + 1)}
    images = dict()
    mergedImages = dict()
    # Load the current pane images from disk.
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0,num_velocities):
            images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
        mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
    # Collect one marker per state, bucketed by (ski_position, velocity).
    for state in states:
        s = state[0]
        marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)]
        markerList[(s.ski_position, s.velocity)].append(marker)
    for (pos, vel), marker in markerList.items():
        if len(marker) == 0: continue
        if drawCircle:
            for m in marker:
                images[(pos,vel)] = cv2.circle(images[(pos,vel)], m[2], 1, m[0], thickness=-1)
                mergedImages[pos] = cv2.circle(mergedImages[pos], m[2], 1, m[0], thickness=-1)
        else:
            for m in marker:
                images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED)
                mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
    # Persist the updated panes.
    for (ski_position, velocity), image in images.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
    for ski_position, image in mergedImages.items():
        cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)
def concatImages(prefix, iteration):
    """Annotate each pane with its 'pXXvYY' label and montage all panes
    (plus the merged per-position panes) into overview images (ImageMagick)."""
    logger.info(f"Concatenating images")
    images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities) for pos in range(1,num_ski_positions+1)]
    mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1,num_ski_positions+1)]
    for vel in range(0, num_velocities):
        for pos in range(1, num_ski_positions + 1):
            # Stamp the position/velocity label into the pane, in place.
            command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
            command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
            command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
            exec(command, verbose=False)
    exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}.png", verbose=False)
    exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration:03}_merged.png", verbose=False)
    #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
    logger.info(f"Concatenating images - DONE")
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster

    NOTE(review): this function is a disabled stub — its entire body is this
    docstring, so calling it is a no-op. The intended ImageMagick-based
    implementation is preserved below for reference.

    TODO
    markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
    s = state[0]
    marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
    markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
    command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
    exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    """
  380. def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
  381. logger.info(f"Drawing {len(clusterDict)} clusters")
  382. tic()
  383. for _, clusterStates in clusterDict.items():
  384. color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
  385. color = (int(color[0]), int(color[1]), int(color[2]))
  386. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
  387. concatImages(target, iteration)
  388. logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
  389. def drawResult(clusterDict, target, iteration, drawnCluster=set()):
  390. logger.info(f"Drawing {len(clusterDict)} results")
  391. tic()
  392. for id, (clusterStates, result) in clusterDict.items():
  393. if id in drawnCluster: continue
  394. # opencv wants BGR
  395. color = (100,100,100)
  396. if result == Verdict.GOOD:
  397. color = (0,200,0)
  398. elif result == Verdict.BAD:
  399. color = (0,0,200)
  400. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
  401. logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
  402. def _init_logger():
  403. logger = logging.getLogger('main')
  404. logger.setLevel(logging.INFO)
  405. handler = logging.StreamHandler(sys.stdout)
  406. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  407. handler.setFormatter(formatter)
  408. logger.addHandler(handler)
def clusterImportantStates(ranking, iteration):
    """Cluster ranked states with KMeans; return {cluster_id: [(State, StateValue), ...]}.

    ranking: sequence of (State, StateValue) pairs. Roughly one cluster per
    15 states is requested. ski_position and velocity are scaled by 20 —
    presumably so they weigh comparably to x/y in the Euclidean distance;
    TODO confirm.
    """
    logger.info(f"Starting to cluster {len(ranking)} states into clusters")
    tic()
    states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
    #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
    kmeans = KMeans(len(states) // 15, random_state=0, n_init="auto").fit(states)
    #dbscan = DBSCAN(eps=5).fit(states)
    #labels = dbscan.labels_
    labels = kmeans.labels_
    # -1 labels only occur with DBSCAN (noise); KMeans never produces them.
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
    clusterDict = {i : list() for i in range(0,n_clusters)}
    strayStates = list()
    for i, state in enumerate(ranking):
        # Each stray (noise) state becomes its own singleton cluster.
        # NOTE(review): keys start at n_clusters + 1, leaving index n_clusters
        # unused — confirm this gap is intentional.
        if labels[i] == -1:
            clusterDict[n_clusters + len(strayStates) + 1] = list()
            clusterDict[n_clusters + len(strayStates) + 1].append(state)
            strayStates.append(state)
            continue
        clusterDict[labels[i]].append(state)
    if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1")
    #drawClusters(clusterDict, f"clusters", iteration)
    return clusterDict
def run_experiment(factor_tests_per_cluster):
    """Run one full verification experiment for a given per-cluster test budget.

    Pipeline: build the PRISM model, model-check it to rank states, cluster
    the highly-ranked states, then for each cluster run emulator tests on a
    random fraction (factor_tests_per_cluster) of its states. A single BAD
    verdict marks the whole cluster unsafe. Result images and metrics are
    snapshotted periodically; a 'data_new_method_<factor>' CSV is written at
    the end.
    """
    logger.info("Starting")
    num_queries = 0
    source = "images/1_full_scaled_down.png"
    # Start every pane (and merged pane) from the clean background image.
    for ski_position in range(1, num_ski_positions + 1):
        for velocity in range(0,num_velocities):
            exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
            exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
        exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)
    goodVerdicts = 0
    badVerdicts = 0
    goodVerdictTestCases = list()
    badVerdictTestCases = list()
    safeClusters = 0
    unsafeClusters = 0
    safeStates = set()
    unsafeStates = set()
    iteration = 0
    results = list()
    # NOTE(review): 'eps' is never used below.
    eps = 0.1
    updatePrismFile(init_mdp, iteration, set(), set())
    #modelCheckingResult, numStates = TestResult(0,0,0,0,0,0,0,0,0,0,0,0,0), 10
    modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_000.prism", iteration)
    results.append(modelCheckingResult)
    ranking = fillStateRanking(f"action_ranking_000")
    sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
    try:
        clusters = clusterImportantStates(sorted_ranking, iteration)
    except Exception as e:
        print(e)
        sys.exit(-1)
    clusterResult = dict()
    logger.info(f"Running tests")
    tic()
    num_cluster_tested = 0
    iteration = 0
    drawnCluster = set()
    for id, cluster in clusters.items():
        # Test at least one state per cluster.
        num_tests = int(factor_tests_per_cluster * len(cluster))
        if num_tests == 0: num_tests = 1
        logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
        randomStates = np.random.choice(len(cluster), num_tests, replace=False)
        randomStates = [cluster[i] for i in randomStates]
        verdictGood = True
        for state in randomStates:
            x = state[0].x
            y = state[0].y
            ski_pos = state[0].ski_position
            velocity = state[0].velocity
            result, num_queries_this_test_case = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
            num_queries += num_queries_this_test_case
            if result == Verdict.BAD:
                # One failing state condemns the entire cluster.
                clusterResult[id] = (cluster, Verdict.BAD)
                verdictGood = False
                unsafeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
                badVerdicts += 1
                badVerdictTestCases.append(state)
            elif result == Verdict.GOOD:
                goodVerdicts += 1
                goodVerdictTestCases.append(state)
        if verdictGood:
            clusterResult[id] = (cluster, Verdict.GOOD)
            safeClusters += 1
            safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
        else:
            unsafeClusters += 1
        # Keep the latest TestResult up to date with the running counters.
        results[-1].safe_states = len(safeStates)
        results[-1].unsafe_states = len(unsafeStates)
        results[-1].policy_queries = num_queries
        results[-1].safe_cluster = safeClusters
        results[-1].unsafe_cluster = unsafeClusters
        results[-1].good_verdicts = goodVerdicts
        results[-1].bad_verdicts = badVerdicts
        num_cluster_tested += 1
        # Snapshot images/metrics roughly every 5% of the clusters.
        # NOTE(review): len(clusters)//20 is 0 for fewer than 20 clusters,
        # which makes this modulo raise ZeroDivisionError — confirm cluster
        # counts are always >= 20 in practice.
        if num_cluster_tested % (len(clusters)//20) == 0:
            iteration += 1
            # NOTE(review): the two ratios below raise ZeroDivisionError while
            # unsafeStates or badVerdicts are still 0.
            logger.info(f"Tested Cluster: {num_cluster_tested:03}\tSafe Cluster States : {len(safeStates)}({safeClusters}/{len(clusters)})\tUnsafe Cluster States:{len(unsafeStates)}({unsafeClusters}/{len(clusters)})\tGood Test Cases:{goodVerdicts}\tFailing Test Cases:{badVerdicts}\t{len(safeStates)/len(unsafeStates)} - {goodVerdicts/badVerdicts}")
            drawResult(clusterResult, "result", iteration, drawnCluster)
            drawOntoSkiPosImage(goodVerdictTestCases, (10,255,50), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
            drawOntoSkiPosImage(badVerdictTestCases, (0,0,0), "result", alpha_factor=0.7, markerSize=0, drawCircle=True)
            concatImages("result", iteration)
            drawnCluster.update(clusterResult.keys())
            #updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
            #modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
            results.append(deepcopy(modelCheckingResult))
            logger.info(f"Model Checking Result: {modelCheckingResult}")
            # Account for self-loop states after first iteration
            if iteration > 0:
                results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafe_states + 0.0*results[-2].safe_states)
                results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafe_states + 1.0*results[-2].safe_states)
    # Emit the collected metrics (the final, partially filled entry is dropped).
    print(TestResult.csv_header())
    for result in results[:-1]:
        print(result.csv())
    with open(f"data_new_method_{factor_tests_per_cluster}", "w") as f:
        f.write(TestResult.csv_header() + "\n")
        for result in results[:-1]:
            f.write(result.csv() + "\n")
# Configure logging, then sweep the per-cluster test budget from 10% to 100%.
_init_logger()
logger = logging.getLogger('main')
if __name__ == '__main__':
    for factor_tests_per_cluster in [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        run_experiment(factor_tests_per_cluster)