You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

548 lines
24 KiB

7 months ago
6 months ago
6 months ago
7 months ago
6 months ago
7 months ago
7 months ago
6 months ago
7 months ago
6 months ago
6 months ago
6 months ago
7 months ago
6 months ago
6 months ago
6 months ago
7 months ago
7 months ago
7 months ago
7 months ago
6 months ago
6 months ago
6 months ago
6 months ago
7 months ago
  1. import sys
  2. import operator
  3. from os import listdir, system
  4. import subprocess
  5. import re
  6. from collections import defaultdict
  7. from random import randrange
  8. from ale_py import ALEInterface, SDL_SUPPORT, Action
  9. from PIL import Image
  10. from matplotlib import pyplot as plt
  11. import cv2
  12. import pickle
  13. import queue
  14. from dataclasses import dataclass, field
  15. from sklearn.cluster import KMeans, DBSCAN
  16. from enum import Enum
  17. from copy import deepcopy
  18. import numpy as np
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. #import readchar
  22. from sample_factory.algo.utils.tensor_dict import TensorDict
  23. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  24. import time
# Machine-specific absolute paths: the Tempest/Storm model-checker binary that
# computeStateRanking invokes, and the Skiing ROM loaded into the ALE emulator.
# NOTE(review): both point into one user's home directory; adjust before
# running elsewhere.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  27. def tic():
  28. import time
  29. global startTime_for_tictoc
  30. startTime_for_tictoc = time.time()
  31. def toc():
  32. import time
  33. if 'startTime_for_tictoc' in globals():
  34. return time.time() - startTime_for_tictoc
  35. class Verdict(Enum):
  36. INCONCLUSIVE = 1
  37. GOOD = 2
  38. BAD = 3
  39. verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  40. def convert(tuples):
  41. return dict(tuples)
  42. @dataclass(frozen=True)
  43. class State:
  44. x: int
  45. y: int
  46. ski_position: int
  47. velocity: int
  48. def default_value():
  49. return {'action' : None, 'choiceValue' : None}
  50. @dataclass(frozen=True)
  51. class StateValue:
  52. ranking: float
  53. choices: dict = field(default_factory=default_value)
  54. @dataclass(frozen=False)
  55. class TestResult:
  56. init_check_pes_min: float
  57. init_check_pes_max: float
  58. init_check_pes_avg: float
  59. init_check_opt_min: float
  60. init_check_opt_max: float
  61. init_check_opt_avg: float
  62. safe_states: int
  63. unsafe_states: int
  64. policy_queries: int
  65. def __str__(self):
  66. return f"""Test Result:
  67. init_check_pes_min: {self.init_check_pes_min}
  68. init_check_pes_max: {self.init_check_pes_max}
  69. init_check_pes_avg: {self.init_check_pes_avg}
  70. init_check_opt_min: {self.init_check_opt_min}
  71. init_check_opt_max: {self.init_check_opt_max}
  72. init_check_opt_avg: {self.init_check_opt_avg}
  73. """
  74. def csv(self, ws=" "):
  75. return f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}{self.safeStates}{ws}{self.unsafeStates}{ws}{self.policy_queries}"
  76. def exec(command,verbose=True):
  77. if verbose: print(f"Executing {command}")
  78. system(f"echo {command} >> list_of_exec")
  79. return system(command)
# NOTE(review): not referenced elsewhere in this file; the number of tests per
# cluster is derived from factor_tests_per_cluster instead.
num_tests_per_cluster = 50
# Fraction of each cluster's states that the main loop samples and tests.
factor_tests_per_cluster = 0.2
# Ski positions are numbered 1..num_ski_positions and velocities
# 0..num_velocities-1; together they index the per-position/velocity images.
num_ski_positions = 8
num_velocities = 5
  84. def input_to_action(char):
  85. if char == "0":
  86. return Action.NOOP
  87. if char == "1":
  88. return Action.RIGHT
  89. if char == "2":
  90. return Action.LEFT
  91. if char == "3":
  92. return "reset"
  93. if char == "4":
  94. return "set_x"
  95. if char == "5":
  96. return "set_vel"
  97. if char in ["w", "a", "s", "d"]:
  98. return char
  99. def saveObservations(observations, verdict, testDir):
  100. testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
  101. if len(observations) < 20:
  102. logger.warn(f"Potentially spurious test case for {testDir}")
  103. testDir = f"{testDir}_pot_spurious"
  104. exec(f"mkdir {testDir}", verbose=False)
  105. for i, obs in enumerate(observations):
  106. img = Image.fromarray(obs)
  107. img.save(f"{testDir}/{i:003}.png")
  108. ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
  109. def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
  110. #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
  111. testDir = f"{x}_{y}_{ski_position}_{velocity}"
  112. try:
  113. for i, r in enumerate(ramDICT[y]):
  114. ale.setRAM(i,r)
  115. ski_position_setting = ski_position_counter[ski_position]
  116. for i in range(0,ski_position_setting[1]):
  117. ale.act(ski_position_setting[0])
  118. ale.setRAM(14,0)
  119. ale.setRAM(25,x)
  120. ale.setRAM(14,180) # TODO
  121. except Exception as e:
  122. print(e)
  123. logger.warn(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
  124. return (Verdict.INCONCLUSIVE, 0)
  125. num_queries = 0
  126. all_obs = list()
  127. speed_list = list()
  128. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  129. for i in range(0,4):
  130. all_obs.append(resized_obs)
  131. for i in range(0,duration-4):
  132. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  133. all_obs.append(resized_obs)
  134. if i % 4 == 0:
  135. stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
  136. action = nn_wrapper.query(stack_tensor)
  137. num_queries += 1
  138. ale.act(input_to_action(str(action)))
  139. else:
  140. ale.act(input_to_action(str(action)))
  141. speed_list.append(ale.getRAM()[14])
  142. if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
  143. #saveObservations(all_obs, Verdict.BAD, testDir)
  144. return (Verdict.BAD, num_queries)
  145. #saveObservations(all_obs, Verdict.GOOD, testDir)
  146. return (Verdict.GOOD, num_queries)
  147. def computeStateRanking(mdp_file, iteration):
  148. logger.info("Computing state ranking")
  149. tic()
  150. prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
  151. prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
  152. prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
  153. prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
  154. prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
  155. prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
  156. prop += 'Rmax=? [C <= 200]'
  157. results = list()
  158. try:
  159. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
  160. output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
  161. for line in output:
  162. #print(line)
  163. if "Result" in line and not len(results) >= 6:
  164. range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
  165. if range_value:
  166. results.append(float(range_value.group(2)))
  167. results.append(float(range_value.group(3)))
  168. else:
  169. value = re.search(r"(.*:)(.*)", line)
  170. results.append(float(value.group(2)))
  171. exec(f"mv action_ranking action_ranking_{iteration:03}")
  172. except subprocess.CalledProcessError as e:
  173. # todo die gracefully if ranking is uniform
  174. print(e.output)
  175. logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
  176. return TestResult(*tuple(results),0,0,0)
  177. def fillStateRanking(file_name, match=""):
  178. logger.info(f"Parsing state ranking, {file_name}")
  179. tic()
  180. state_ranking = dict()
  181. try:
  182. with open(file_name, "r") as f:
  183. file_content = f.readlines()
  184. for line in file_content:
  185. if not "move=0" in line: continue
  186. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  187. if ranking_value <= 0.1:
  188. continue
  189. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  190. #print("stateMapping", stateMapping)
  191. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  192. choices = {key:float(value) for (key,value) in choices.items()}
  193. #print("choices", choices)
  194. #print("ranking_value", ranking_value)
  195. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
  196. #state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
  197. value = StateValue(ranking_value, choices)
  198. state_ranking[state] = value
  199. logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
  200. return state_ranking
  201. except EnvironmentError:
  202. print("Ranking file not available. Exiting.")
  203. toc()
  204. sys.exit(-1)
  205. except:
  206. toc()
  207. def createDisjunction(formulas):
  208. return " | ".join(formulas)
  209. def statesFormulaTrimmed(states):
  210. if len(states) == 0: return "false"
  211. #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
  212. skiPositionGroup = defaultdict(list)
  213. for item in states:
  214. skiPositionGroup[item[2]].append(item)
  215. formulas = list()
  216. for skiPosition, skiPos_group in skiPositionGroup.items():
  217. formula = f"( ski_position={skiPosition} & "
  218. firstVelocity = True
  219. velocityGroup = defaultdict(list)
  220. for item in skiPos_group:
  221. velocityGroup[item[3]].append(item)
  222. for velocity, velocity_group in velocityGroup.items():
  223. if firstVelocity:
  224. firstVelocity = False
  225. else:
  226. formula += " | "
  227. formula += f" (velocity={velocity} & "
  228. firstY = True
  229. yPosGroup = defaultdict(list)
  230. for item in velocity_group:
  231. yPosGroup[item[1]].append(item)
  232. for y, y_group in yPosGroup.items():
  233. if firstY:
  234. firstY = False
  235. else:
  236. formula += " | "
  237. sorted_y_group = sorted(y_group, key=lambda s: s[0])
  238. formula += f"( y={y} & ("
  239. current_x_min = sorted_y_group[0][0]
  240. current_x = sorted_y_group[0][0]
  241. x_ranges = list()
  242. for state in sorted_y_group[1:-1]:
  243. if state[0] - current_x == 1:
  244. current_x = state[0]
  245. else:
  246. x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
  247. current_x_min = state[0]
  248. current_x = state[0]
  249. x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})")
  250. formula += " | ".join(x_ranges)
  251. formula += ") )"
  252. formula += ")"
  253. formula += ")"
  254. formulas.append(formula)
  255. print(formulas)
  256. sys.exit(1)
  257. return createBalancedDisjunction(formulas)
  258. # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
  259. def pairwise(iterable):
  260. "s -> (s0, s1), (s2, s3), (s4, s5), ..."
  261. a = iter(iterable)
  262. return zip(a, a)
  263. def createBalancedDisjunction(formulas):
  264. if len(formulas) == 0:
  265. return "false"
  266. while len(formulas) > 1:
  267. formulas_tmp = [f"({f} | {g})" for f,g in pairwise(formulas)]
  268. if len(formulas) % 2 == 1:
  269. formulas_tmp.append(formulas[-1])
  270. formulas = formulas_tmp
  271. return " ".join(formulas)
  272. def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
  273. logger.info("Creating next prism file")
  274. tic()
  275. initFile = f"{newFile}_no_formulas.prism"
  276. newFile = f"{newFile}_{iteration:03}.prism"
  277. exec(f"cp {initFile} {newFile}", verbose=False)
  278. with open(newFile, "a") as prism:
  279. prism.write(f"formula Safe = {statesFormulaTrimmed(safeStates)};\n")
  280. prism.write(f"formula Unsafe = {statesFormulaTrimmed(unsafeStates)};\n")
  281. prism.write(f"label \"Safe\" = Safe;\n")
  282. prism.write(f"label \"Unsafe\" = Unsafe;\n")
  283. logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
# Module-level setup; this runs on import, not only under __main__. It creates
# the emulator, loads the ROM and the RAM snapshots, builds the policy wrapper,
# and creates this run's image directory.
ale = ALEInterface()
#if SDL_SUPPORT:
# ale.setBool("sound", True)
# ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# RAM snapshots indexed by y: run_single_test writes every byte of ramDICT[y]
# into the emulator RAM to start a test at that y. pickle.load executes
# arbitrary code from the file, so only use a trusted, locally produced pickle.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
# NOTE(review): y_ram_setting is not referenced elsewhere in this file, and the
# initial x = 70 is overwritten by the __main__ loop before it is read.
y_ram_setting = 60
x = 70
nn_wrapper = SampleFactoryNNQueryWrapper()
# Unix timestamp that identifies this run; all images go to images/testing_<id>.
experiment_id = int(time.time())
# Base name of the PRISM model: <init_mdp>_no_formulas.prism is the template,
# and <init_mdp>_<iteration:03>.prism is the per-iteration copy (see updatePrismFile).
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
# Half-width in pixels of the square marker drawn for each state.
markerSize = 1
imagesDir = f"images/testing_{experiment_id}"
  300. def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
  301. #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
  302. markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities) for ski_position in range(1,num_ski_positions + 1)}
  303. images = dict()
  304. mergedImages = dict()
  305. for ski_position in range(1, num_ski_positions + 1):
  306. for velocity in range(0,num_velocities):
  307. images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
  308. mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
  309. for state in states:
  310. s = state[0]
  311. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  312. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
  313. marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)]
  314. markerList[(s.ski_position, s.velocity)].append(marker)
  315. for (pos, vel), marker in markerList.items():
  316. #command = f"convert {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png"
  317. #exec(command, verbose=False)
  318. if len(marker) == 0: continue
  319. for m in marker:
  320. images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED)
  321. mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
  322. for (ski_position, velocity), image in images.items():
  323. cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
  324. for ski_position, image in mergedImages.items():
  325. cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)
  326. def concatImages(prefix, iteration):
  327. logger.info(f"Concatenating images")
  328. images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities) for pos in range(1,num_ski_positions+1)]
  329. mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1,num_ski_positions+1)]
  330. for vel in range(0, num_velocities):
  331. for pos in range(1, num_ski_positions + 1):
  332. command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
  333. command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
  334. command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
  335. exec(command, verbose=False)
  336. exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
  337. exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}_merged.png", verbose=False)
  338. #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
  339. logger.info(f"Concatenating images - DONE")
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster

    NOTE(review): this function is currently a no-op. Its entire body is this
    docstring, so calling it draws nothing and returns None. The text below is
    a retired ImageMagick (convert/montage) implementation kept for reference.
    It keys markers by ski position only (1-8, no velocity axis), unlike
    drawOntoSkiPosImage.

    markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    """
  356. def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
  357. logger.info(f"Drawing {len(clusterDict)} clusters")
  358. tic()
  359. #for velocity in range(0, num_velocities):
  360. # for ski_position in range(1, num_ski_positions + 1):
  361. # source = "images/1_full_scaled_down.png"
  362. # exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  363. for _, clusterStates in clusterDict.items():
  364. color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
  365. color = (int(color[0]), int(color[1]), int(color[2]))
  366. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
  367. concatImages(target, iteration)
  368. logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
  369. def drawResult(clusterDict, target, iteration):
  370. logger.info(f"Drawing {len(clusterDict)} results")
  371. #for velocity in range(0,num_velocities):
  372. # for ski_position in range(1, num_ski_positions + 1):
  373. # source = "images/1_full_scaled_down.png"
  374. # exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  375. for _, (clusterStates, result) in clusterDict.items():
  376. # opencv wants BGR
  377. color = (100,100,100)
  378. if result == Verdict.GOOD:
  379. color = (0,200,0)
  380. elif result == Verdict.BAD:
  381. color = (0,0,200)
  382. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
  383. concatImages(target, iteration)
  384. logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
  385. def _init_logger():
  386. logger = logging.getLogger('main')
  387. logger.setLevel(logging.INFO)
  388. handler = logging.StreamHandler(sys.stdout)
  389. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  390. handler.setFormatter(formatter)
  391. logger.addHandler(handler)
  392. def clusterImportantStates(ranking, iteration):
  393. logger.info(f"Starting to cluster {len(ranking)} states into clusters")
  394. tic()
  395. states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
  396. #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
  397. #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
  398. dbscan = DBSCAN(eps=5).fit(states)
  399. labels = dbscan.labels_
  400. n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
  401. logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
  402. clusterDict = {i : list() for i in range(0,n_clusters)}
  403. strayStates = list()
  404. for i, state in enumerate(ranking):
  405. if labels[i] == -1:
  406. clusterDict[n_clusters + len(strayStates) + 1] = list()
  407. clusterDict[n_clusters + len(strayStates) + 1].append(state)
  408. strayStates.append(state)
  409. continue
  410. clusterDict[labels[i]].append(state)
  411. if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1")
  412. drawClusters(clusterDict, f"clusters", iteration)
  413. return clusterDict
  414. if __name__ == '__main__':
  415. _init_logger()
  416. logger = logging.getLogger('main')
  417. logger.info("Starting")
  418. testAll = False
  419. num_queries = 0
  420. source = "images/1_full_scaled_down.png"
  421. for ski_position in range(1, num_ski_positions + 1):
  422. for velocity in range(0,num_velocities):
  423. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  424. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  425. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
  426. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)
  427. safeStates = set()
  428. unsafeStates = set()
  429. iteration = 0
  430. results = list()
  431. eps = 0.1
  432. while True:
  433. updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
  434. modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
  435. if len(results) > 0:
  436. modelCheckingResult.safeStates = results[-1].safeStates
  437. modelCheckingResult.unsafeStates = results[-1].unsafeStates
  438. modelCheckingResult.num_queries = results[-1].num_queries
  439. results.append(modelCheckingResult)
  440. logger.info(f"Model Checking Result: {modelCheckingResult}")
  441. if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
  442. logger.info(f"Absolute difference between average estimates is below eps = {eps}... finishing!")
  443. break
  444. ranking = fillStateRanking(f"action_ranking_{iteration:03}")
  445. sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
  446. try:
  447. clusters = clusterImportantStates(sorted_ranking, iteration)
  448. except Exception as e:
  449. print(e)
  450. break
  451. if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
  452. clusterResult = dict()
  453. logger.info(f"Running tests")
  454. tic()
  455. for id, cluster in clusters.items():
  456. num_tests = int(factor_tests_per_cluster * len(cluster))
  457. #logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
  458. randomStates = np.random.choice(len(cluster), num_tests, replace=False)
  459. randomStates = [cluster[i] for i in randomStates]
  460. verdictGood = True
  461. for state in randomStates:
  462. x = state[0].x
  463. y = state[0].y
  464. ski_pos = state[0].ski_position
  465. velocity = state[0].velocity
  466. result, num_queries_this_test_case = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
  467. num_queries += num_queries_this_test_case
  468. if result == Verdict.BAD:
  469. if testAll:
  470. failingPerCluster[id].append(state)
  471. else:
  472. clusterResult[id] = (cluster, Verdict.BAD)
  473. verdictGood = False
  474. unsafeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  475. break
  476. if verdictGood:
  477. clusterResult[id] = (cluster, Verdict.GOOD)
  478. safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  479. logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}")
  480. results[-1].safeStates = len(safeStates)
  481. results[-1].unsafeStates = len(unsafeStates)
  482. results[-1].num_queries = num_queries
  483. if testAll: drawClusters(failingPerCluster, f"failing", iteration)
  484. drawResult(clusterResult, "result", iteration)
  485. iteration += 1
  486. for result in results:
  487. print(result.csv())