You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

587 lines
26 KiB

6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
6 months ago
  1. import sys
  2. import operator
  3. from os import listdir, system
  4. import subprocess
  5. import re
  6. from collections import defaultdict
  7. from random import randrange
  8. from ale_py import ALEInterface, SDL_SUPPORT, Action
  9. from PIL import Image
  10. from matplotlib import pyplot as plt
  11. import cv2
  12. import pickle
  13. import queue
  14. from dataclasses import dataclass, field
  15. from sklearn.cluster import KMeans, DBSCAN
  16. from enum import Enum
  17. from copy import deepcopy
  18. import numpy as np
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. #import readchar
  22. from sample_factory.algo.utils.tensor_dict import TensorDict
  23. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  24. import time
  25. tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
  26. rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  27. def tic():
  28. import time
  29. global startTime_for_tictoc
  30. startTime_for_tictoc = time.time()
  31. def toc():
  32. import time
  33. if 'startTime_for_tictoc' in globals():
  34. return time.time() - startTime_for_tictoc
  35. class Verdict(Enum):
  36. INCONCLUSIVE = 1
  37. GOOD = 2
  38. BAD = 3
  39. verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  40. def convert(tuples):
  41. return dict(tuples)
  42. @dataclass(frozen=True)
  43. class State:
  44. x: int
  45. y: int
  46. ski_position: int
  47. velocity: int
  48. def default_value():
  49. return {'action' : None, 'choiceValue' : None}
  50. @dataclass(frozen=True)
  51. class StateValue:
  52. ranking: float
  53. choices: dict = field(default_factory=default_value)
  54. @dataclass(frozen=False)
  55. class TestResult:
  56. init_check_pes_min: float
  57. init_check_pes_max: float
  58. init_check_pes_avg: float
  59. init_check_opt_min: float
  60. init_check_opt_max: float
  61. init_check_opt_avg: float
  62. safe_states: int
  63. unsafe_states: int
  64. policy_queries: int
  65. def __str__(self):
  66. return f"""Test Result:
  67. init_check_pes_min: {self.init_check_pes_min}
  68. init_check_pes_max: {self.init_check_pes_max}
  69. init_check_pes_avg: {self.init_check_pes_avg}
  70. init_check_opt_min: {self.init_check_opt_min}
  71. init_check_opt_max: {self.init_check_opt_max}
  72. init_check_opt_avg: {self.init_check_opt_avg}
  73. """
  74. def csv(self, ws=" "):
  75. return f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}{self.safeStates}{ws}{self.unsafeStates}{ws}{self.policy_queries}"
  76. def exec(command,verbose=True):
  77. if verbose: print(f"Executing {command}")
  78. system(f"echo {command} >> list_of_exec")
  79. return system(command)
  80. num_tests_per_cluster = 50
  81. factor_tests_per_cluster = 0.2
  82. num_ski_positions = 8
  83. num_velocities = 5
  84. def input_to_action(char):
  85. if char == "0":
  86. return Action.NOOP
  87. if char == "1":
  88. return Action.RIGHT
  89. if char == "2":
  90. return Action.LEFT
  91. if char == "3":
  92. return "reset"
  93. if char == "4":
  94. return "set_x"
  95. if char == "5":
  96. return "set_vel"
  97. if char in ["w", "a", "s", "d"]:
  98. return char
  99. def saveObservations(observations, verdict, testDir):
  100. testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
  101. if len(observations) < 20:
  102. logger.warn(f"Potentially spurious test case for {testDir}")
  103. testDir = f"{testDir}_pot_spurious"
  104. exec(f"mkdir {testDir}", verbose=False)
  105. for i, obs in enumerate(observations):
  106. img = Image.fromarray(obs)
  107. img.save(f"{testDir}/{i:003}.png")
  108. ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
  109. def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
  110. #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
  111. testDir = f"{x}_{y}_{ski_position}_{velocity}"
  112. try:
  113. for i, r in enumerate(ramDICT[y]):
  114. ale.setRAM(i,r)
  115. ski_position_setting = ski_position_counter[ski_position]
  116. for i in range(0,ski_position_setting[1]):
  117. ale.act(ski_position_setting[0])
  118. ale.setRAM(14,0)
  119. ale.setRAM(25,x)
  120. ale.setRAM(14,180) # TODO
  121. except Exception as e:
  122. print(e)
  123. logger.warn(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
  124. return (Verdict.INCONCLUSIVE, 0)
  125. num_queries = 0
  126. all_obs = list()
  127. speed_list = list()
  128. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  129. for i in range(0,4):
  130. all_obs.append(resized_obs)
  131. for i in range(0,duration-4):
  132. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  133. all_obs.append(resized_obs)
  134. if i % 4 == 0:
  135. stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
  136. action = nn_wrapper.query(stack_tensor)
  137. num_queries += 1
  138. ale.act(input_to_action(str(action)))
  139. else:
  140. ale.act(input_to_action(str(action)))
  141. speed_list.append(ale.getRAM()[14])
  142. if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
  143. #saveObservations(all_obs, Verdict.BAD, testDir)
  144. return (Verdict.BAD, num_queries)
  145. #saveObservations(all_obs, Verdict.GOOD, testDir)
  146. return (Verdict.GOOD, num_queries)
  147. def skiPositionFormulaList(name):
  148. formulas = list()
  149. for i in range(1, num_ski_positions+1):
  150. formulas.append(f"\"{name}_{i}\"")
  151. return createBalancedDisjunction(formulas)
  152. def computeStateRanking(mdp_file, iteration):
  153. logger.info("Computing state ranking")
  154. tic()
  155. prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  156. prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  157. prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  158. prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  159. prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  160. prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
  161. prop += 'Rmax=? [C <= 200]'
  162. results = list()
  163. try:
  164. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
  165. output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
  166. num_states = 0
  167. for line in output:
  168. #print(line)
  169. if "States:" in line:
  170. num_states = int(line.split(" ")[-1])
  171. if "Result" in line and not len(results) >= 6:
  172. range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
  173. if range_value:
  174. results.append(float(range_value.group(2)))
  175. results.append(float(range_value.group(3)))
  176. else:
  177. value = re.search(r"(.*:)(.*)", line)
  178. results.append(float(value.group(2)))
  179. exec(f"mv action_ranking action_ranking_{iteration:03}")
  180. except subprocess.CalledProcessError as e:
  181. # todo die gracefully if ranking is uniform
  182. print(e.output)
  183. logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
  184. return TestResult(*tuple(results),0,0,0), num_states
  185. def fillStateRanking(file_name, match=""):
  186. logger.info(f"Parsing state ranking, {file_name}")
  187. tic()
  188. state_ranking = dict()
  189. try:
  190. with open(file_name, "r") as f:
  191. file_content = f.readlines()
  192. for line in file_content:
  193. if not "move=0" in line: continue
  194. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  195. if ranking_value <= 0.1:
  196. continue
  197. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  198. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  199. choices = {key:float(value) for (key,value) in choices.items()}
  200. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
  201. value = StateValue(ranking_value, choices)
  202. state_ranking[state] = value
  203. logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
  204. return state_ranking
  205. except EnvironmentError:
  206. print("Ranking file not available. Exiting.")
  207. toc()
  208. sys.exit(-1)
  209. except:
  210. toc()
  211. def createDisjunction(formulas):
  212. return " | ".join(formulas)
  213. def statesFormulaTrimmed(states, name):
  214. #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
  215. skiPositionGroup = defaultdict(list)
  216. for item in states:
  217. skiPositionGroup[item[2]].append(item)
  218. formulas = list()
  219. for skiPosition, skiPos_group in skiPositionGroup.items():
  220. formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & "
  221. firstVelocity = True
  222. velocityGroup = defaultdict(list)
  223. for item in skiPos_group:
  224. velocityGroup[item[3]].append(item)
  225. for velocity, velocity_group in velocityGroup.items():
  226. if firstVelocity:
  227. firstVelocity = False
  228. else:
  229. formula += " | "
  230. formulasPerSkiPosition = list()
  231. formula += f" (velocity={velocity} & "
  232. firstY = True
  233. yPosGroup = defaultdict(list)
  234. yAndXRanges = dict()
  235. for item in velocity_group:
  236. yPosGroup[item[1]].append(item)
  237. for y, y_group in yPosGroup.items():
  238. sorted_y_group = sorted(y_group, key=lambda s: s[0])
  239. #formula += f"( y={y} & ("
  240. current_x_min = sorted_y_group[0][0]
  241. current_x = sorted_y_group[0][0]
  242. x_ranges = list()
  243. for state in sorted_y_group[1:-1]:
  244. if state[0] - current_x == 1:
  245. current_x = state[0]
  246. else:
  247. x_ranges.append(f" ({current_x_min}<=x&x<={current_x})")
  248. current_x_min = state[0]
  249. current_x = state[0]
  250. x_ranges.append(f" ({current_x_min}<=x&x<={sorted_y_group[-1][0]})")
  251. xRangesDisjunction = createBalancedDisjunction(x_ranges)
  252. if xRangesDisjunction in yAndXRanges:
  253. yAndXRanges[xRangesDisjunction].append(y)
  254. else:
  255. yAndXRanges[xRangesDisjunction] = list()
  256. yAndXRanges[xRangesDisjunction].append(y)
  257. for xRange, ys in yAndXRanges.items():
  258. #if firstY:
  259. # firstY = False
  260. #else:
  261. # formula += " | "
  262. sorted_ys = sorted(ys)
  263. if len(ys) == 1:
  264. formulasPerSkiPosition.append(f" ({xRange} & y={ys[0]})")
  265. continue
  266. current_y_min = sorted_ys[0]
  267. current_y = sorted_ys[0]
  268. y_ranges = list()
  269. for y in sorted_ys[1:]:
  270. if y - current_y == 2:
  271. current_y = y
  272. elif abs(y - current_y) > 2:
  273. y_ranges.append(f" ({current_y_min}<=y&y<={current_y})")
  274. current_y_min = y
  275. current_y = y
  276. y_ranges.append(f" ({current_y_min}<=y&y<={sorted_ys[-1]})")
  277. formulasPerSkiPosition.append(f" ({xRange} & ({createBalancedDisjunction(y_ranges)}))")
  278. formula += createBalancedDisjunction(formulasPerSkiPosition)
  279. formula += ")"
  280. formula += ");"
  281. formulas.append(formula)
  282. for i in range(1, num_ski_positions+1):
  283. if i in skiPositionGroup:
  284. continue
  285. formulas.append(f"formula {name}_{i} = false;")
  286. return "\n".join(formulas) + "\n"
  287. # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
  288. def pairwise(iterable):
  289. "s -> (s0, s1), (s2, s3), (s4, s5), ..."
  290. a = iter(iterable)
  291. return zip(a, a)
  292. def createBalancedDisjunction(formulas):
  293. if len(formulas) == 0:
  294. return "false"
  295. while len(formulas) > 1:
  296. formulas_tmp = [f"({f} | {g})" for f,g in pairwise(formulas)]
  297. if len(formulas) % 2 == 1:
  298. formulas_tmp.append(formulas[-1])
  299. formulas = formulas_tmp
  300. return " ".join(formulas)
  301. def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
  302. logger.info("Creating next prism file")
  303. tic()
  304. initFile = f"{newFile}_no_formulas.prism"
  305. newFile = f"{newFile}_{iteration:03}.prism"
  306. exec(f"cp {initFile} {newFile}", verbose=False)
  307. with open(newFile, "a") as prism:
  308. prism.write(statesFormulaTrimmed(safeStates, "Safe"))
  309. prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe"))
  310. for i in range(1,num_ski_positions+1):
  311. prism.write(f"label \"Safe_{i}\" = Safe_{i};\n")
  312. prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n")
  313. logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
# --- Module-level setup: everything below runs on import ------------------
# Emulator instance shared with run_single_test(); display/sound stay disabled.
ale = ALEInterface()
#if SDL_SUPPORT:
# ale.setBool("sound", True)
# ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# RAM snapshots indexed by y; run_single_test() writes ramDICT[y] into the emulator RAM.
# NOTE(review): pickle.load can execute arbitrary code — only load a trusted, locally generated file.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
# NOTE(review): y_ram_setting and x are not read by the functions in this file
# (run_single_test takes x as a parameter; the main loop rebinds x).
y_ram_setting = 60
x = 70
nn_wrapper = SampleFactoryNNQueryWrapper()
# Unix timestamp identifying this run; names the output image directory.
experiment_id = int(time.time())
# Base name of the PRISM model; updatePrismFile() reads "<init_mdp>_no_formulas.prism".
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
# Half-width (pixels) of the square drawn for each state.
markerSize = 1
imagesDir = f"images/testing_{experiment_id}"
  330. def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
  331. #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
  332. markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities) for ski_position in range(1,num_ski_positions + 1)}
  333. images = dict()
  334. mergedImages = dict()
  335. for ski_position in range(1, num_ski_positions + 1):
  336. for velocity in range(0,num_velocities):
  337. images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
  338. mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
  339. for state in states:
  340. s = state[0]
  341. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  342. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
  343. marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)]
  344. markerList[(s.ski_position, s.velocity)].append(marker)
  345. for (pos, vel), marker in markerList.items():
  346. #command = f"convert {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png"
  347. #exec(command, verbose=False)
  348. if len(marker) == 0: continue
  349. for m in marker:
  350. images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED)
  351. mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
  352. for (ski_position, velocity), image in images.items():
  353. cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
  354. for ski_position, image in mergedImages.items():
  355. cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)
  356. def concatImages(prefix, iteration):
  357. logger.info(f"Concatenating images")
  358. images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities) for pos in range(1,num_ski_positions+1)]
  359. mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1,num_ski_positions+1)]
  360. for vel in range(0, num_velocities):
  361. for pos in range(1, num_ski_positions + 1):
  362. command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
  363. command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
  364. command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
  365. exec(command, verbose=False)
  366. exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
  367. exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}_merged.png", verbose=False)
  368. #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
  369. logger.info(f"Concatenating images - DONE")
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    # NOTE(review): this function is currently a no-op that returns None.
    # Its former body has been disabled by wrapping it in the string literal
    # below, which Python treats as the docstring. That body drew the states
    # with `convert` into one image per ski position and montaged them into
    # "<target>.png". No caller is visible in this file.
    """
Useful to draw a set of states, e.g. a single cluster
markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
logger.info(f"Drawing {len(states)} states onto {target}")
tic()
for state in states:
s = state[0]
marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
markerList[s.ski_position].append(marker)
for pos, marker in markerList.items():
command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
exec(command, verbose=False)
exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
"""
  386. def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
  387. logger.info(f"Drawing {len(clusterDict)} clusters")
  388. tic()
  389. #for velocity in range(0, num_velocities):
  390. # for ski_position in range(1, num_ski_positions + 1):
  391. # source = "images/1_full_scaled_down.png"
  392. # exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  393. for _, clusterStates in clusterDict.items():
  394. color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
  395. color = (int(color[0]), int(color[1]), int(color[2]))
  396. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
  397. concatImages(target, iteration)
  398. logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
  399. def drawResult(clusterDict, target, iteration):
  400. logger.info(f"Drawing {len(clusterDict)} results")
  401. #for velocity in range(0,num_velocities):
  402. # for ski_position in range(1, num_ski_positions + 1):
  403. # source = "images/1_full_scaled_down.png"
  404. # exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  405. for _, (clusterStates, result) in clusterDict.items():
  406. # opencv wants BGR
  407. color = (100,100,100)
  408. if result == Verdict.GOOD:
  409. color = (0,200,0)
  410. elif result == Verdict.BAD:
  411. color = (0,0,200)
  412. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
  413. concatImages(target, iteration)
  414. logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
  415. def _init_logger():
  416. logger = logging.getLogger('main')
  417. logger.setLevel(logging.INFO)
  418. handler = logging.StreamHandler(sys.stdout)
  419. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  420. handler.setFormatter(formatter)
  421. logger.addHandler(handler)
  422. def clusterImportantStates(ranking, iteration):
  423. logger.info(f"Starting to cluster {len(ranking)} states into clusters")
  424. tic()
  425. states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
  426. #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
  427. #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
  428. dbscan = DBSCAN(eps=5).fit(states)
  429. labels = dbscan.labels_
  430. n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
  431. logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
  432. clusterDict = {i : list() for i in range(0,n_clusters)}
  433. strayStates = list()
  434. for i, state in enumerate(ranking):
  435. if labels[i] == -1:
  436. clusterDict[n_clusters + len(strayStates) + 1] = list()
  437. clusterDict[n_clusters + len(strayStates) + 1].append(state)
  438. strayStates.append(state)
  439. continue
  440. clusterDict[labels[i]].append(state)
  441. if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1")
  442. drawClusters(clusterDict, f"clusters", iteration)
  443. return clusterDict
  444. if __name__ == '__main__':
  445. _init_logger()
  446. logger = logging.getLogger('main')
  447. logger.info("Starting")
  448. testAll = False
  449. num_queries = 0
  450. source = "images/1_full_scaled_down.png"
  451. for ski_position in range(1, num_ski_positions + 1):
  452. for velocity in range(0,num_velocities):
  453. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  454. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  455. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
  456. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)
  457. safeStates = set()
  458. unsafeStates = set()
  459. iteration = 0
  460. results = list()
  461. eps = 0.1
  462. while True:
  463. updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
  464. modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
  465. if len(results) > 0:
  466. modelCheckingResult.safeStates = results[-1].safeStates
  467. modelCheckingResult.unsafeStates = results[-1].unsafeStates
  468. modelCheckingResult.policy_queries = results[-1].policy_queries
  469. results.append(modelCheckingResult)
  470. logger.info(f"Model Checking Result: {modelCheckingResult}")
  471. if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
  472. logger.info(f"Absolute difference between average estimates is below eps = {eps}... finishing!")
  473. break
  474. ranking = fillStateRanking(f"action_ranking_{iteration:03}")
  475. sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
  476. try:
  477. clusters = clusterImportantStates(sorted_ranking, iteration)
  478. except Exception as e:
  479. print(e)
  480. break
  481. if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
  482. clusterResult = dict()
  483. logger.info(f"Running tests")
  484. tic()
  485. for id, cluster in clusters.items():
  486. num_tests = int(factor_tests_per_cluster * len(cluster))
  487. #logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
  488. randomStates = np.random.choice(len(cluster), num_tests, replace=False)
  489. randomStates = [cluster[i] for i in randomStates]
  490. verdictGood = True
  491. for state in randomStates:
  492. x = state[0].x
  493. y = state[0].y
  494. ski_pos = state[0].ski_position
  495. velocity = state[0].velocity
  496. result, num_queries_this_test_case = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
  497. num_queries += num_queries_this_test_case
  498. if result == Verdict.BAD:
  499. if testAll:
  500. failingPerCluster[id].append(state)
  501. else:
  502. clusterResult[id] = (cluster, Verdict.BAD)
  503. verdictGood = False
  504. unsafeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  505. break
  506. if verdictGood:
  507. clusterResult[id] = (cluster, Verdict.GOOD)
  508. safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  509. logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}")
  510. results[-1].safeStates = len(safeStates)
  511. results[-1].unsafeStates = len(unsafeStates)
  512. results[-1].policy_queries = num_queries
  513. # Account for self-loop states after first iteration
  514. if iteration > 0:
  515. results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafeStates + 0.0*results[-2].safeStates)
  516. results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafeStates + 1.0*results[-2].safeStates)
  517. for result in results:
  518. print(result.csv())
  519. if testAll: drawClusters(failingPerCluster, f"failing", iteration)
  520. drawResult(clusterResult, "result", iteration)
  521. iteration += 1
  522. for result in results:
  523. print(result.csv())