You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

376 lines
16 KiB

  1. #!/usr/bin/python3
import argparse
import fileinput
import os
import re
import shutil
import subprocess
import sys
from dataclasses import dataclass, field

import numpy as np
import visvis as vv

from plotting import VisVisPlotter
from simulation import Simulator, Verdict
from translate import translateTransitions, readLabels
  10. def convert(tuples):
  11. return dict(tuples)
  12. def getBasename(filename):
  13. return os.path.basename(filename)
  14. def traFileWithIteration(filename, iteration):
  15. return os.path.splitext(filename)[0] + f"_{iteration:03}.tra"
  16. def copyFile(filename, newFilename):
  17. shutil.copy(filename, newFilename)
  18. def execute(command, verbose=False):
  19. if verbose: print(f"Executing {command}")
  20. os.system(command)
  21. @dataclass(frozen=True)
  22. class State:
  23. id: int
  24. x: float
  25. x_vel: float
  26. y: float
  27. y_vel: float
  28. z: float
  29. z_vel: float
  30. def default_value():
  31. return {'action' : None, 'choiceValue' : None}
  32. @dataclass(frozen=True)
  33. class StateValue:
  34. ranking: float
  35. choices: dict = field(default_factory=default_value)
  36. @dataclass(frozen=False)
  37. class TestResult:
  38. prob_pes_min: float
  39. prob_pes_max: float
  40. prob_pes_avg: float
  41. prob_opt_min: float
  42. prob_opt_max: float
  43. prob_opt_avg: float
  44. min_min: float
  45. min_max: float
  46. def csv(self, ws=" "):
  47. return f"{self.prob_pes_min:0.04f}{ws}{self.prob_pes_max:0.04f}{ws}{self.prob_pes_avg:0.04f}{ws}{self.prob_opt_min:0.04f}{ws}{self.prob_opt_max:0.04f}{ws}{self.prob_opt_avg:0.04f}{ws}{self.min_min:0.04f}{ws}{self.min_max:0.04f}{ws}"
  48. def parseStrategy(strategyFile, allStateActionPairs, time_index=0):
  49. strategy = dict()
  50. with open(strategyFile) as strategyLines:
  51. for line in strategyLines:
  52. line = line.replace("(","").replace(")","").replace("\n", "")
  53. explode = re.split(",|=", line)
  54. stateId = int(explode[0]) + 3
  55. if stateId < 3: continue
  56. if int(explode[1]) != time_index: continue
  57. try:
  58. strategy[stateId] = allStateActionPairs[stateId].index(explode[2])
  59. except KeyError as e:
  60. pass
  61. return strategy
  62. def queryStrategy(strategy, stateId):
  63. try:
  64. return strategy[stateId]
  65. except:
  66. return -1
  67. def callTempest(files, reward, bound=3):
  68. property_str = "!(\"failed\" | \"reached\")"
  69. if True:
  70. prop = f"filter(min, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );"
  71. prop += f"filter(max, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );"
  72. prop += f"filter(avg, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );"
  73. prop += f"filter(min, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );"
  74. prop += f"filter(max, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );"
  75. prop += f"filter(avg, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );"
  76. prop += f"filter(min, Rmin=? [ C<={bound} ], {property_str} );"
  77. prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str} );"
  78. else:
  79. prop = f"filter(min, Rmin=? [ C<={bound} ], {property_str} );"
  80. prop += f"filter(max, Rmin=? [ C<={bound} ], {property_str} );"
  81. prop += f"filter(avg, Rmin=? [ C<={bound} ], {property_str} );"
  82. prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str} );"
  83. prop += f"filter(max, Rmax=? [ C<={bound} ], {property_str} );"
  84. prop += f"filter(avg, Rmax=? [ C<={bound} ], {property_str} );"
  85. command = f"~/projects/tempest-devel/ranking_release/bin/storm --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' "
  86. results = list()
  87. try:
  88. output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
  89. for line in output:
  90. if "Result" in line and not len(results) >= 10:
  91. range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
  92. if range_value:
  93. results.append(float(range_value.group(2)))
  94. results.append(float(range_value.group(3)))
  95. else:
  96. value = re.search(r"(.*:)(.*)", line)
  97. results.append(float(value.group(2)))
  98. except subprocess.CalledProcessError as e:
  99. print(e.output)
  100. #results.append(-1)
  101. #results.append(-1)
  102. return TestResult(*(tuple(results)))
  103. def parseRanking(filename, allStates):
  104. state_ranking = dict()
  105. try:
  106. with open(filename, "r") as f:
  107. filecontent = f.readlines()
  108. for line in filecontent:
  109. stateId = int(re.findall(r"^\d+", line)[0])
  110. values = re.findall(r":(-?\d+\.?\d*),?", line)
  111. ranking_value = float(values[0])
  112. choices = {i : float(value) for i,value in enumerate(values[1:])}
  113. state = allStates[stateId]
  114. value = StateValue(ranking_value, choices)
  115. state_ranking[state] = value
  116. if len(state_ranking) == 0: return
  117. all_values = [x.ranking for x in state_ranking.values()]
  118. max_value = max(all_values)
  119. min_value = min(all_values)
  120. new_state_ranking = {}
  121. for state, value in state_ranking.items():
  122. choices = value.choices
  123. try:
  124. new_value = (value.ranking - min_value) / (max_value - min_value)
  125. except ZeroDivisionError as e:
  126. new_value = 0.0
  127. new_state_ranking[state] = StateValue(new_value, choices)
  128. state_ranking = new_state_ranking
  129. except EnvironmentError:
  130. print("TODO file not available. Exiting.")
  131. sys.exit(1)
  132. return {state: values for state, values in sorted(state_ranking.items(), key=lambda item: item[1].ranking)}
  133. def parseStateValuations(filename):
  134. all_states = dict()
  135. maxStateId = -1
  136. for i in [0,1,2]:
  137. dummy_values = [i] * 7
  138. all_states[i] = State(*dummy_values)
  139. with open(filename) as stateValuations:
  140. for line in stateValuations:
  141. values = re.findall(r"(-?\d+\.?\d*),?", line)
  142. values = [int(values[0])] + [float(v) for v in values[1:]]
  143. all_states[values[0]] = State(*values)
  144. if values[0] > maxStateId: maxStateId = values[0]
  145. dummy_values = [maxStateId + 1] * 7
  146. all_states[maxStateId + 1] = State(*dummy_values)
  147. return all_states
  148. def parseResults(allStates):
  149. state_to_values = dict()
  150. with open("prob_results_maximize") as maximizer, open("prob_results_minimize") as minimizer:
  151. for max_line, min_line in zip(maximizer, minimizer):
  152. max_values = re.findall(r"(-?\d+\.?\d*),?", max_line)
  153. min_values = re.findall(r"(-?\d+\.?\d*),?", min_line)
  154. if max_values[0] != min_values[0]:
  155. print("min/max files do not match.")
  156. assert(False)
  157. stateId = int(max_values[0])
  158. min_result = float(min_values[1])
  159. max_result = float(max_values[1])
  160. value = (min_result, max_result, max_result - min_result)
  161. state_to_values[stateId] = value
  162. return state_to_values
  163. def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration):
  164. stateIdRegex = re.compile(f"^{stateId}\s")
  165. for line in fileinput.input(filename, inplace = True):
  166. if not stateIdRegex.match(line):
  167. print(line, end="")
  168. else:
  169. explode = line.split(" ")
  170. if int(explode[1]) == chosenActionIndex:
  171. print(line, end="")
  172. def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration):
  173. stateIdsRegex = re.compile("|".join([f"^{stateId}\s" for stateId, actionIndex in stateActionPairsToTrim.items()]))
  174. for line in fileinput.input(filename, inplace = True):
  175. result = stateIdsRegex.match(line)
  176. if not result:
  177. print(line, end="")
  178. else:
  179. actionIndex = stateActionPairsToTrim[int(result[0])]
  180. explode = line.split(" ")
  181. if int(explode[1]) == actionIndex:
  182. print(line, end="")
  183. def getTopNStates(rankedStates, n, threshold):
  184. if n != 0:
  185. return dict(list(rankedStates.items())[-n:])
  186. else:
  187. return {state:value for state,value in rankedStates.items() if value.ranking >= threshold}
  188. def getNRandomStates(rankedStates, n, testedStates):
  189. stateIds = [state.id for state in rankedStates.keys()]
  190. notYetTestedStates = np.array([stateId for stateId in stateIds if stateId not in testedStates])
  191. if len(notYetTestedStates) >= n:
  192. return notYetTestedStates[np.random.choice(len(notYetTestedStates), size=n, replace=False)]
  193. else:
  194. return notYetTestedStates
  195. def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False):
  196. all_states = parseStateValuations("MDP_state_valuations")
  197. deadlockStates, reachedStates, maxStateId = readLabels(labFile)
  198. stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
  199. strategy = parseStrategy(straFile, stateToActions)
  200. if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
  201. if plotting: plotter.plotScenario()
  202. copyFile("MDP_" + traFile, "MDP_" + os.path.splitext(getBasename(traFile))[0] + f"_000.tra")
  203. iteration = 0
  204. #testsPerIteration = refinementSteps
  205. #refinementThreshold =
  206. numTestedStates = 0
  207. totalIterations = 60
  208. testedStates = list()
  209. while iteration < totalIterations:
  210. print(f"{iteration:03}", end="\t")
  211. sys.stdout.flush()
  212. currentTraFile = traFileWithIteration("MDP_" + traFile, iteration)
  213. nextTraFile = traFileWithIteration("MDP_" + traFile, iteration+1)
  214. testResult = callTempest(f"{currentTraFile} MDP_{labFile}", "saferew", horizonBound)
  215. state_ranking = parseRanking("action_ranking", all_states)
  216. copyFile("action_ranking", f"action_ranking_{iteration:03}")
  217. copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}")
  218. copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}")
  219. if not ablationTesting:
  220. importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound)
  221. statesToTest = [state.id for state in importantStates.keys()]
  222. statesToPlot = importantStates
  223. else:
  224. statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates))
  225. testedStates += statesToTest
  226. statesToPlot = {all_states[stateId]:StateValue(0,{}) for stateId in statesToTest}
  227. copyFile(currentTraFile, nextTraFile)
  228. stateActionPairsToTrim = dict()
  229. for testState in statesToTest:
  230. chosenActionIndex = queryStrategy(strategy, testState)
  231. if chosenActionIndex != -1:
  232. stateActionPairsToTrim[testState] = chosenActionIndex
  233. stateEstimates = parseResults(all_states)
  234. results = [0,0,0]
  235. failureStates = list()
  236. validatedStates = list()
  237. for state, estimates in stateEstimates.items():
  238. if state in deadlockStates or state in reachedStates:
  239. continue
  240. if estimates[0] > 0.05:
  241. results[1] += 1
  242. failureStates.append(all_states[state])
  243. #print(f"{state}: {estimates}")
  244. elif estimates[1] <= 0.05:
  245. results[0] += 1
  246. validatedStates.append(all_states[state])
  247. else:
  248. results[2] += 1
  249. removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration)
  250. print(f"{numTestedStates}\t{testResult.csv(' ')}\t{results[0]}\t{results[1]}\t{results[2]}\t{sum(results)}")
  251. if results[2] == 0:
  252. sys.exit(0)
  253. numTestedStates += len(statesToTest)
  254. iteration += 1
  255. if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6), removeMeshes=True)
  256. if plotting: plotter.plotStates(validatedStates, coloring=(0.0,0.8,0.0,0.6))
  257. if plotting: plotter.takeScreenshot(iteration, prefix="stepwise_0.05")
  258. def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False):
  259. all_states = parseStateValuations("MDP_state_valuations")
  260. deadlockStates, reachedStates, maxStateId = readLabels(labFile)
  261. stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
  262. strategy = parseStrategy(straFile, stateToActions)
  263. if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
  264. if plotting: plotter.plotScenario()
  265. passingStates = list()
  266. randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound)
  267. i = 0
  268. print("Starting with random testing.")
  269. numQueries = 0
  270. failureStates = list()
  271. while numQueries <= maxQueries:
  272. if i >= 500:
  273. if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6))
  274. if plotting: plotter.takeScreenshot(iteration, prefix="random_testing")
  275. if plotting: plotter.turnCamera()
  276. if plotting: input("")
  277. print(f"{numQueries} {len(failureStates)} ")
  278. i = 0
  279. testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest()
  280. i += queriesForThisTestCase
  281. numQueries += queriesForThisTestCase
  282. stateValuation = all_states[testCase]
  283. if testResult == Verdict.FAIL:
  284. failureStates.append(stateValuation)
  285. print(f"{numQueries} {len(failureStates)} ")
  286. def parseArgs():
  287. parser = argparse.ArgumentParser()
  288. parser.add_argument('--tra', type=str, required=True, help='Path to .tra file.')
  289. parser.add_argument('--lab', type=str, required=True, help='Path to .lab file.')
  290. parser.add_argument('--rew', type=str, required=True, help='Path to .rew file.')
  291. parser.add_argument('--stra', type=str, required=True, help='Path to strategy file.')
  292. refinement = parser.add_mutually_exclusive_group(required=True)
  293. refinement.add_argument('--refinement-steps', type=int, default=0, help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.')
  294. refinement.add_argument('--refinement-bound', type=float, default=0, help='Threshold value for states to be tested, mutually exclusive with refinement-steps.')
  295. parser.add_argument('--bound', type=int, required=False, default=3, help='(optional) Safety Horizon Bound, defaults to 3.')
  296. parser.add_argument('--threshold', type=float, required=False, default=0.05, help='(optional) Safety Threshold, defaults to 0.05.')
  297. random_testing = parser.add_mutually_exclusive_group()
  298. random_testing.add_argument('-a', '--ablation', action='store_true', help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.")
  299. random_testing.add_argument('-r', '--random', type=int, default=0, help='(optional) The amount of queries allowed for random testing.')
  300. parser.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.')
  301. parser.add_argument('--stepwise', action='store_true', help='(optional) Remove states before plotting the next iteration.')
  302. return parser.parse_args()
  303. if __name__ == '__main__':
  304. args = parseArgs()
  305. traFile = args.tra
  306. labFile = args.lab
  307. straFile = args.stra
  308. rewFile = args.rew
  309. ablationTesting = args.ablation
  310. plotting = args.plotting
  311. stepwisePlotting = args.stepwise
  312. maxQueriesForRandomTesting = args.random
  313. horizonBound = args.bound
  314. refinementSteps = args.refinement_steps
  315. refinementBound = args.refinement_bound
  316. if maxQueriesForRandomTesting == 0: #akward way to test for this...
  317. main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting)
  318. else:
  319. randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting)