You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

531 lines
23 KiB

7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
7 months ago
  1. import sys
  2. import operator
  3. from os import listdir, system
  4. import subprocess
  5. import re
  6. from collections import defaultdict
  7. from random import randrange
  8. from ale_py import ALEInterface, SDL_SUPPORT, Action
  9. from PIL import Image
  10. from matplotlib import pyplot as plt
  11. import cv2
  12. import pickle
  13. import queue
  14. from dataclasses import dataclass, field
  15. from sklearn.cluster import KMeans, DBSCAN
  16. from enum import Enum
  17. from copy import deepcopy
  18. import numpy as np
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. #import readchar
  22. from sample_factory.algo.utils.tensor_dict import TensorDict
  23. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  24. import time
# Absolute, machine-specific path to the model-checker executable that
# computeStateRanking() runs through the shell (the path ends in bin/storm).
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
# Absolute, machine-specific path to the Skiing ROM handed to ale.loadROM().
rom_file = "/home/spranger/research/Skiing/env/lib/python3.10/site-packages/AutoROM/roms/skiing.bin"
  27. def tic():
  28. import time
  29. global startTime_for_tictoc
  30. startTime_for_tictoc = time.time()
  31. def toc():
  32. import time
  33. if 'startTime_for_tictoc' in globals():
  34. return time.time() - startTime_for_tictoc
  35. class Verdict(Enum):
  36. INCONCLUSIVE = 1
  37. GOOD = 2
  38. BAD = 3
  39. verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  40. def convert(tuples):
  41. return dict(tuples)
  42. @dataclass(frozen=True)
  43. class State:
  44. x: int
  45. y: int
  46. ski_position: int
  47. velocity: int
  48. def default_value():
  49. return {'action' : None, 'choiceValue' : None}
  50. @dataclass(frozen=True)
  51. class StateValue:
  52. ranking: float
  53. choices: dict = field(default_factory=default_value)
  54. @dataclass(frozen=False)
  55. class TestResult:
  56. init_check_pes_min: float
  57. init_check_pes_max: float
  58. init_check_pes_avg: float
  59. init_check_opt_min: float
  60. init_check_opt_max: float
  61. init_check_opt_avg: float
  62. def __str__(self):
  63. return f"""Test Result:
  64. init_check_pes_min: {self.init_check_pes_min}
  65. init_check_pes_max: {self.init_check_pes_max}
  66. init_check_pes_avg: {self.init_check_pes_avg}
  67. init_check_opt_min: {self.init_check_opt_min}
  68. init_check_opt_max: {self.init_check_opt_max}
  69. init_check_opt_avg: {self.init_check_opt_avg}
  70. """
  71. def csv(self, ws=" "):
  72. return f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}"
  73. def exec(command,verbose=True):
  74. if verbose: print(f"Executing {command}")
  75. system(f"echo {command} >> list_of_exec")
  76. return system(command)
# NOTE(review): not referenced anywhere in the visible part of this file.
num_tests_per_cluster = 50
# Fraction of each cluster's states that the main loop samples for testing.
factor_tests_per_cluster = 0.2
# Ski positions are iterated as 1..num_ski_positions.
num_ski_positions = 8
# Velocities are iterated as 0..num_velocities-1.
num_velocities = 5
  81. def input_to_action(char):
  82. if char == "0":
  83. return Action.NOOP
  84. if char == "1":
  85. return Action.RIGHT
  86. if char == "2":
  87. return Action.LEFT
  88. if char == "3":
  89. return "reset"
  90. if char == "4":
  91. return "set_x"
  92. if char == "5":
  93. return "set_vel"
  94. if char in ["w", "a", "s", "d"]:
  95. return char
  96. def saveObservations(observations, verdict, testDir):
  97. testDir = f"images/testing_{experiment_id}/{verdict.name}_{testDir}_{len(observations)}"
  98. if len(observations) < 20:
  99. logger.warn(f"Potentially spurious test case for {testDir}")
  100. testDir = f"{testDir}_pot_spurious"
  101. exec(f"mkdir {testDir}", verbose=False)
  102. for i, obs in enumerate(observations):
  103. img = Image.fromarray(obs)
  104. img.save(f"{testDir}/{i:003}.png")
  105. ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
  106. def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
  107. #def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
  108. #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
  109. testDir = f"{x}_{y}_{ski_position}_{velocity}"
  110. #testDir = f"{x}_{y}_{ski_position}"
  111. for i, r in enumerate(ramDICT[y]):
  112. ale.setRAM(i,r)
  113. ski_position_setting = ski_position_counter[ski_position]
  114. for i in range(0,ski_position_setting[1]):
  115. ale.act(ski_position_setting[0])
  116. ale.setRAM(14,0)
  117. ale.setRAM(25,x)
  118. ale.setRAM(14,180) # TODO
  119. all_obs = list()
  120. speed_list = list()
  121. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  122. for i in range(0,4):
  123. all_obs.append(resized_obs)
  124. for i in range(0,duration-4):
  125. resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
  126. all_obs.append(resized_obs)
  127. if i % 4 == 0:
  128. stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
  129. action = nn_wrapper.query(stack_tensor)
  130. ale.act(input_to_action(str(action)))
  131. else:
  132. ale.act(input_to_action(str(action)))
  133. speed_list.append(ale.getRAM()[14])
  134. if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
  135. #saveObservations(all_obs, Verdict.BAD, testDir)
  136. return Verdict.BAD
  137. #saveObservations(all_obs, Verdict.GOOD, testDir)
  138. return Verdict.GOOD
  139. def computeStateRanking(mdp_file, iteration):
  140. logger.info("Computing state ranking")
  141. tic()
  142. prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
  143. prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
  144. prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
  145. prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
  146. prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
  147. prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
  148. prop += 'Rmax=? [C <= 200]'
  149. results = list()
  150. try:
  151. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
  152. output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
  153. for line in output:
  154. print(line)
  155. if "Result" in line and not len(results) >= 6:
  156. range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
  157. if range_value:
  158. results.append(float(range_value.group(2)))
  159. results.append(float(range_value.group(3)))
  160. else:
  161. value = re.search(r"(.*:)(.*)", line)
  162. results.append(float(value.group(2)))
  163. exec(f"mv action_ranking action_ranking_{iteration:03}")
  164. except subprocess.CalledProcessError as e:
  165. # todo die gracefully if ranking is uniform
  166. print(e.output)
  167. logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
  168. return TestResult(*tuple(results))
  169. def fillStateRanking(file_name, match=""):
  170. logger.info(f"Parsing state ranking, {file_name}")
  171. tic()
  172. state_ranking = dict()
  173. try:
  174. with open(file_name, "r") as f:
  175. file_content = f.readlines()
  176. for line in file_content:
  177. if not "move=0" in line: continue
  178. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  179. if ranking_value <= 0.1:
  180. continue
  181. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  182. #print("stateMapping", stateMapping)
  183. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  184. choices = {key:float(value) for (key,value) in choices.items()}
  185. #print("choices", choices)
  186. #print("ranking_value", ranking_value)
  187. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
  188. #state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
  189. value = StateValue(ranking_value, choices)
  190. state_ranking[state] = value
  191. logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
  192. return state_ranking
  193. except EnvironmentError:
  194. print("Ranking file not available. Exiting.")
  195. toc()
  196. sys.exit(-1)
  197. except:
  198. toc()
  199. def createDisjunction(formulas):
  200. return " | ".join(formulas)
  201. def statesFormulaTrimmed(states):
  202. if len(states) == 0: return "false"
  203. #states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
  204. skiPositionGroup = defaultdict(list)
  205. for item in states:
  206. skiPositionGroup[item[2]].append(item)
  207. formulas = list()
  208. for skiPosition, skiPos_group in skiPositionGroup.items():
  209. formula = f"( ski_position={skiPosition} & "
  210. firstVelocity = True
  211. velocityGroup = defaultdict(list)
  212. for item in skiPos_group:
  213. velocityGroup[item[3]].append(item)
  214. for velocity, velocity_group in velocityGroup.items():
  215. if firstVelocity:
  216. firstVelocity = False
  217. else:
  218. formula += " | "
  219. formula += f" (velocity={velocity} & "
  220. firstY = True
  221. yPosGroup = defaultdict(list)
  222. for item in velocity_group:
  223. yPosGroup[item[1]].append(item)
  224. for y, y_group in yPosGroup.items():
  225. if firstY:
  226. firstY = False
  227. else:
  228. formula += " | "
  229. sorted_y_group = sorted(y_group, key=lambda s: s[0])
  230. formula += f"( y={y} & ("
  231. current_x_min = sorted_y_group[0][0]
  232. current_x = sorted_y_group[0][0]
  233. x_ranges = list()
  234. for state in sorted_y_group[1:-1]:
  235. if state[0] - current_x == 1:
  236. current_x = state[0]
  237. else:
  238. x_ranges.append(f" ({current_x_min}<= x & x<={current_x})")
  239. current_x_min = state[0]
  240. current_x = state[0]
  241. x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})")
  242. formula += " | ".join(x_ranges)
  243. formula += ") )"
  244. formula += ")"
  245. formula += ")"
  246. formulas.append(formula)
  247. return createBalancedDisjunction(formulas)
  248. # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
  249. def pairwise(iterable):
  250. "s -> (s0, s1), (s2, s3), (s4, s5), ..."
  251. a = iter(iterable)
  252. return zip(a, a)
  253. def createBalancedDisjunction(formulas):
  254. if len(formulas) == 0:
  255. return "false"
  256. while len(formulas) > 1:
  257. formulas_tmp = [f"({f} | {g})" for f,g in pairwise(formulas)]
  258. if len(formulas) % 2 == 1:
  259. formulas_tmp.append(formulas[-1])
  260. formulas = formulas_tmp
  261. return " ".join(formulas)
  262. def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
  263. logger.info("Creating next prism file")
  264. tic()
  265. initFile = f"{newFile}_no_formulas.prism"
  266. newFile = f"{newFile}_{iteration:03}.prism"
  267. exec(f"cp {initFile} {newFile}", verbose=False)
  268. with open(newFile, "a") as prism:
  269. prism.write(f"formula Safe = {statesFormulaTrimmed(safeStates)};\n")
  270. prism.write(f"formula Unsafe = {statesFormulaTrimmed(unsafeStates)};\n")
  271. prism.write(f"label \"Safe\" = Safe;\n")
  272. prism.write(f"label \"Unsafe\" = Unsafe;\n")
  273. logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
# Module-level setup; this runs on import as well as when executed as a script.
ale = ALEInterface()
#if SDL_SUPPORT:
#    ale.setBool("sound", True)
#    ale.setBool("display_screen", True)
# Load the ROM file
ale.loadROM(rom_file)
# RAM snapshots indexed by y: run_single_test() writes ramDICT[y] byte by byte
# into the emulator. NOTE(review): pickle.load executes arbitrary code from the
# file, so only load a pickle that you created yourself.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
# NOTE(review): y_ram_setting and x are not read anywhere in the visible code;
# the main loop later rebinds x.
y_ram_setting = 60
x = 70
nn_wrapper = SampleFactoryNNQueryWrapper()
# Timestamp that makes the per-run image directory unique.
experiment_id = int(time.time())
# File-name prefix of the PRISM models ("{init_mdp}_no_formulas.prism", ...).
init_mdp = "velocity_safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
# Half edge length, in pixels, of the square drawn for each state.
markerSize = 1
imagesDir = f"images/testing_{experiment_id}"
  290. def drawOntoSkiPosImage(states, color, target_prefix="cluster_", alpha_factor=1.0):
  291. #markerList = {ski_position:list() for ski_position in range(1,num_ski_positions + 1)}
  292. markerList = {(ski_position, velocity):list() for velocity in range(0, num_velocities) for ski_position in range(1,num_ski_positions + 1)}
  293. images = dict()
  294. mergedImages = dict()
  295. for ski_position in range(1, num_ski_positions + 1):
  296. for velocity in range(0,num_velocities):
  297. images[(ski_position, velocity)] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png")
  298. mergedImages[ski_position] = cv2.imread(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png")
  299. for state in states:
  300. s = state[0]
  301. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
  302. #marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'point {s.x},{s.y} '"
  303. marker = [color, alpha_factor * state[1].ranking, (s.x-markerSize, s.y-markerSize), (s.x+markerSize, s.y+markerSize)]
  304. markerList[(s.ski_position, s.velocity)].append(marker)
  305. for (pos, vel), marker in markerList.items():
  306. #command = f"convert {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png {' '.join(marker)} {imagesDir}/{target_prefix}_{pos:02}_{vel:02}_individual.png"
  307. #exec(command, verbose=False)
  308. if len(marker) == 0: continue
  309. for m in marker:
  310. images[(pos,vel)] = cv2.rectangle(images[(pos,vel)], m[2], m[3], m[0], cv2.FILLED)
  311. mergedImages[pos] = cv2.rectangle(mergedImages[pos], m[2], m[3], m[0], cv2.FILLED)
  312. for (ski_position, velocity), image in images.items():
  313. cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_{velocity:02}_individual.png", image)
  314. for ski_position, image in mergedImages.items():
  315. cv2.imwrite(f"{imagesDir}/{target_prefix}_{ski_position:02}_individual.png", image)
  316. def concatImages(prefix, iteration):
  317. logger.info(f"Concatenating images")
  318. images = [f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png" for vel in range(0,num_velocities) for pos in range(1,num_ski_positions+1)]
  319. mergedImages = [f"{imagesDir}/{prefix}_{pos:02}_individual.png" for pos in range(1,num_ski_positions+1)]
  320. for vel in range(0, num_velocities):
  321. for pos in range(1, num_ski_positions + 1):
  322. command = f"convert {imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png "
  323. command += f"-pointsize 10 -gravity NorthEast -annotate +8+0 'p{pos:02}v{vel:02}' "
  324. command += f"{imagesDir}/{prefix}_{pos:02}_{vel:02}_individual.png"
  325. exec(command, verbose=False)
  326. exec(f"montage {' '.join(images)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}.png", verbose=False)
  327. exec(f"montage {' '.join(mergedImages)} -geometry +0+0 -tile 8x9 {imagesDir}/{prefix}_{iteration}_merged.png", verbose=False)
  328. #exec(f"sxiv {imagesDir}/{prefix}_{iteration}.png&", verbose=False)
  329. logger.info(f"Concatenating images - DONE")
# NOTE: this function is currently a no-op. Its body consists solely of the
# docstring, which still holds the former ImageMagick-based implementation.
def drawStatesOntoTiledImage(states, color, target, source="images/1_full_scaled_down.png", alpha_factor=1.0):
    """
    Useful to draw a set of states, e.g. a single cluster
    markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
    logger.info(f"Drawing {len(states)} states onto {target}")
    tic()
    for state in states:
        s = state[0]
        marker = f"-fill 'rgba({color}, {alpha_factor * state[1].ranking})' -draw 'rectangle {s.x-markerSize},{s.y-markerSize} {s.x+markerSize},{s.y+markerSize} '"
        markerList[s.ski_position].append(marker)
    for pos, marker in markerList.items():
        command = f"convert {source} {' '.join(marker)} {imagesDir}/{target}_{pos:02}_individual.png"
        exec(command, verbose=False)
    exec(f"montage {imagesDir}/{target}_*_individual.png -geometry +0+0 -tile x1 {imagesDir}/{target}.png", verbose=False)
    logger.info(f"Drawing {len(states)} states onto {target} - Done: took {toc()} seconds")
    """
  346. def drawClusters(clusterDict, target, iteration, alpha_factor=1.0):
  347. logger.info(f"Drawing {len(clusterDict)} clusters")
  348. tic()
  349. #for velocity in range(0, num_velocities):
  350. # for ski_position in range(1, num_ski_positions + 1):
  351. # source = "images/1_full_scaled_down.png"
  352. # exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  353. for _, clusterStates in clusterDict.items():
  354. color = (np.random.choice(range(256)), np.random.choice(range(256)), np.random.choice(range(256)))
  355. color = (int(color[0]), int(color[1]), int(color[2]))
  356. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=alpha_factor)
  357. concatImages(target, iteration)
  358. logger.info(f"Drawing {len(clusterDict)} clusters - DONE: took {toc()} seconds")
  359. def drawResult(clusterDict, target, iteration):
  360. logger.info(f"Drawing {len(clusterDict)} results")
  361. #for velocity in range(0,num_velocities):
  362. # for ski_position in range(1, num_ski_positions + 1):
  363. # source = "images/1_full_scaled_down.png"
  364. # exec(f"cp {source} {imagesDir}/{target}_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  365. for _, (clusterStates, result) in clusterDict.items():
  366. # opencv wants BGR
  367. color = (100,100,100)
  368. if result == Verdict.GOOD:
  369. color = (0,200,0)
  370. elif result == Verdict.BAD:
  371. color = (0,0,200)
  372. drawOntoSkiPosImage(clusterStates, color, target, alpha_factor=0.7)
  373. concatImages(target, iteration)
  374. logger.info(f"Drawing {len(clusterDict)} results - DONE: took {toc()} seconds")
  375. def _init_logger():
  376. logger = logging.getLogger('main')
  377. logger.setLevel(logging.INFO)
  378. handler = logging.StreamHandler(sys.stdout)
  379. formatter = logging.Formatter( '[%(levelname)s] %(module)s - %(message)s')
  380. handler.setFormatter(formatter)
  381. logger.addHandler(handler)
  382. def clusterImportantStates(ranking, iteration):
  383. logger.info(f"Starting to cluster {len(ranking)} states into clusters")
  384. tic()
  385. states = [[s[0].x,s[0].y, s[0].ski_position * 20, s[0].velocity * 20, s[1].ranking] for s in ranking]
  386. #states = [[s[0].x,s[0].y, s[0].ski_position * 30, s[1].ranking] for s in ranking]
  387. #kmeans = KMeans(n_clusters, random_state=0, n_init="auto").fit(states)
  388. dbscan = DBSCAN(eps=5).fit(states)
  389. labels = dbscan.labels_
  390. n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
  391. logger.info(f"Starting to cluster {len(ranking)} states into clusters - DONE: took {toc()} seconds with {n_clusters} cluster")
  392. clusterDict = {i : list() for i in range(0,n_clusters)}
  393. strayStates = list()
  394. for i, state in enumerate(ranking):
  395. if labels[i] == -1:
  396. clusterDict[n_clusters + len(strayStates) + 1] = list()
  397. clusterDict[n_clusters + len(strayStates) + 1].append(state)
  398. strayStates.append(state)
  399. continue
  400. clusterDict[labels[i]].append(state)
  401. if len(strayStates) > 0: logger.warning(f"{len(strayStates)} stray states with label -1")
  402. drawClusters(clusterDict, f"clusters", iteration)
  403. return clusterDict
  404. if __name__ == '__main__':
  405. _init_logger()
  406. logger = logging.getLogger('main')
  407. logger.info("Starting")
  408. n_clusters = 40
  409. testAll = False
  410. source = "images/1_full_scaled_down.png"
  411. for ski_position in range(1, num_ski_positions + 1):
  412. for velocity in range(0,num_velocities):
  413. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  414. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_{velocity:02}_individual.png", verbose=False)
  415. exec(f"cp {source} {imagesDir}/clusters_{ski_position:02}_individual.png", verbose=False)
  416. exec(f"cp {source} {imagesDir}/result_{ski_position:02}_individual.png", verbose=False)
  417. safeStates = set()
  418. unsafeStates = set()
  419. iteration = 0
  420. results = list()
  421. eps = 0.1
  422. while True:
  423. updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
  424. modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
  425. results.append(modelCheckingResult)
  426. logger.info(f"Model Checking Result: {modelCheckingResult}")
  427. if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
  428. logger.info(f"Absolute difference between average estimates is below eps = {eps}... finishing!")
  429. break
  430. ranking = fillStateRanking(f"action_ranking_{iteration:03}")
  431. sorted_ranking = sorted( (x for x in ranking.items() if x[1].ranking > 0.1), key=lambda x: x[1].ranking)
  432. try:
  433. clusters = clusterImportantStates(sorted_ranking, iteration)
  434. except Exception as e:
  435. print(e)
  436. break
  437. if testAll: failingPerCluster = {i: list() for i in range(0, n_clusters)}
  438. clusterResult = dict()
  439. logger.info(f"Running tests")
  440. tic()
  441. for id, cluster in clusters.items():
  442. num_tests = int(factor_tests_per_cluster * len(cluster))
  443. #logger.info(f"Testing {num_tests} states (from {len(cluster)} states) from cluster {id}")
  444. randomStates = np.random.choice(len(cluster), num_tests, replace=False)
  445. randomStates = [cluster[i] for i in randomStates]
  446. verdictGood = True
  447. for state in randomStates:
  448. x = state[0].x
  449. y = state[0].y
  450. ski_pos = state[0].ski_position
  451. velocity = state[0].velocity
  452. #result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
  453. result = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
  454. if result == Verdict.BAD:
  455. if testAll:
  456. failingPerCluster[id].append(state)
  457. else:
  458. clusterResult[id] = (cluster, Verdict.BAD)
  459. verdictGood = False
  460. unsafeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  461. break
  462. if verdictGood:
  463. clusterResult[id] = (cluster, Verdict.GOOD)
  464. safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
  465. logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}")
  466. if testAll: drawClusters(failingPerCluster, f"failing", iteration)
  467. drawResult(clusterResult, "result", iteration)
  468. iteration += 1
  469. for result in results:
  470. print(result.csv())