commit
						b3be734b8f
					
				 4 changed files with 667 additions and 0 deletions
			
			
		- 
					88plotting.py
- 
					68simulation.py
- 
					376test_model.py
- 
					135translate.py
| @ -0,0 +1,88 @@ | |||
| #!/usr/bin/python3 | |||
| 
 | |||
| import visvis as vv | |||
| import numpy as np | |||
| 
 | |||
| import time | |||
| 
 | |||
| def translateArrayToVV(array): | |||
|     scaling = (float(array[0][1]) - float(array[0][0]), float(array[2][1]) - float(array[2][0]), float(array[4][1]) - float(array[4][0])) | |||
|     translation = (float(array[0][0]) + scaling[0] * 0.5, float(array[2][0]) + scaling[0] * 0.5, float(array[4][0]) + scaling[0] * 0.5) | |||
|     return translation, scaling | |||
| 
 | |||
| class VisVisPlotter(): | |||
|     def __init__(self, stateValuations, goalStates, deadlockStates, stepwise): | |||
|         self.app = vv.use() | |||
|         self.fig = vv.clf() | |||
|         self.ax = vv.cla() | |||
|         self.stateColor = (0,0,0.75,0.4) | |||
|         self.failColor = (0.8,0.8,0.8,1.0)  # (1,0,0,0.8) | |||
|         self.goalColor = (0,1,0,1.0) | |||
|         self.stateScaling = (2,2,2) #(.5,.5,.5) | |||
| 
 | |||
|         self.plotedStates = set() | |||
|         self.plotedMeshes = set() | |||
|         self.stepwise = stepwise | |||
| 
 | |||
|         auxStates = [0,1,2,25518] | |||
|         self.goals = set([(stateValuation.x, stateValuation.y, stateValuation.z) for stateId, stateValuation in stateValuations.items() if stateId in goalStates and not stateId in auxStates]) | |||
|         self.fails = set([(stateValuation.x, stateValuation.y, stateValuation.z) for stateId, stateValuation in stateValuations.items() if stateId in deadlockStates and not stateId in auxStates]) | |||
| 
 | |||
|     def plotScenario(self, saveScreenshot=True): | |||
|         for goal in self.goals: | |||
|             state = vv.solidBox(goal, scaling=self.stateScaling) | |||
|             state.faceColor = self.goalColor | |||
|         for fail in self.fails: | |||
|             state = vv.solidBox(fail, scaling=self.stateScaling) | |||
|             state.faceColor = self.failColor | |||
| 
 | |||
|         self.ax.SetLimits((-16,16),(-10,10),(-7,7)) | |||
|         self.ax.SetView({'zoom':0.025, 'elevation':20, 'azimuth':30}) | |||
|         if saveScreenshot: vv.screenshot("000.png", sf=3, bg='w', ob=vv.gcf()) | |||
| 
 | |||
|     def run(self): | |||
|         self.ax.SetLimits((-16,16),(-10,10),(-7,7)) | |||
|         self.app.Run() | |||
| 
 | |||
|     def clear(self): | |||
|         axes = vv.gca() | |||
|         axes.Clear() | |||
| 
 | |||
|     def plotStates(self, states, coloring="", removeMeshes=False): | |||
|         if self.stepwise and removeMeshes: | |||
|             self.clear() | |||
|             self.plotScenario(saveScreenshot=False) | |||
|         if not coloring: | |||
|             coloring = self.stateColor | |||
|         plotedRegions = set() | |||
|         for state in states: | |||
|             if state in self.plotedStates: continue # what are the implications of this? | |||
|             coordinates = (state.x, state.y, state.z) | |||
|             print(f"plotting {state} at {coordinates}") | |||
|             if coordinates in plotedRegions: continue | |||
|             state = vv.solidBox(coordinates, scaling=(1,1,1))#(0.5,0.5,0.5)) | |||
|             state.faceColor = coloring | |||
|             plotedRegions.add(coordinates) | |||
|             if self.stepwise: | |||
|                 self.plotedMeshes.add(state) | |||
|         self.plotedStates.update(states) | |||
| 
 | |||
|     def takeScreenshot(self, iteration, prefix=""): | |||
|         self.ax.SetLimits((-16,16),(-10,10),(-7,7)) | |||
|         #config = [(90, 00), (45, 00), (0, 00), (-45, 00), (-90, 00)] | |||
|         config = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)] | |||
|         for elevation, azimuth in config: | |||
|             filename = f"{prefix}{'_' if prefix else ''}{iteration:03}_{elevation}_{azimuth}.png" | |||
|             print(f"Saving Screenshot to {filename}") | |||
|             self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth}) | |||
|             vv.screenshot(filename, sf=3, bg='w', ob=vv.gcf()) | |||
| 
 | |||
|     def turnCamera(self): | |||
|         config = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)] | |||
|         self.ax.SetLimits((-16,16),(-10,10),(-7,7)) | |||
|         for elevation, azimuth in config: | |||
|             self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth}) | |||
| 
 | |||
| 
 | |||
| if __name__ == '__main__': | |||
|     main() | |||
| @ -0,0 +1,68 @@ | |||
| import sys | |||
| from enum import Flag, auto | |||
| 
 | |||
| import numpy as np | |||
| 
 | |||
| 
 | |||
| class Verdict(Flag): | |||
|     INCONCLUSIVE = auto() | |||
|     PASS = auto() | |||
|     FAIL = auto() | |||
| 
 | |||
| 
 | |||
| 
 | |||
| class Simulator(): | |||
|     def __init__(self, allStateActionPairs, strategy, deadlockStates, reachedStates, bound=3, numSimulations=1): | |||
|         self.allStateActionPairs = { ( pair.state_id, pair.action_id ) : pair.next_state_probabilities for pair in allStateActionPairs } | |||
|         self.strategy = strategy | |||
|         self.deadlockStates = deadlockStates | |||
|         self.reachedStates = reachedStates | |||
| 
 | |||
|         #print(f"Deadlock: {self.deadlockStates}") | |||
|         #print(f"GoalStates: {self.reachedStates}") | |||
| 
 | |||
| 
 | |||
|         self.bound = bound | |||
|         self.numSimulations = numSimulations | |||
| 
 | |||
|         allStates = set([state.state_id for state in allStateActionPairs]) | |||
|         allStates = allStates.difference(set(deadlockStates)) | |||
|         allStates = allStates.difference(set(reachedStates)) | |||
|         self.allStates = np.array(list(allStates)) | |||
| 
 | |||
|     def _pickRandomTestCase(self): | |||
|         testCase = np.random.choice(self.allStates, 1)[0] | |||
|         #self.allStates = np.delete(self.allStates, testCase) | |||
|         return testCase | |||
| 
 | |||
|     def _simulate(self, initialStateId): | |||
|         i = 0 | |||
| 
 | |||
|         actionId = self.strategy[initialStateId] | |||
|         nextStatePair = (initialStateId, actionId) | |||
| 
 | |||
|         while i < self.bound: | |||
|             i += 1 | |||
|             nextStateProbabilities = self.allStateActionPairs[nextStatePair] | |||
|             weights = list() | |||
|             nextStateIds = list() | |||
|             for nextStateId, probability in nextStateProbabilities.items(): | |||
|                 weights.append(probability) | |||
|                 nextStateIds.append(nextStateId) | |||
|             nextStateId = np.random.choice(nextStateIds, 1, p=weights)[0] | |||
|             if nextStateId in self.deadlockStates: | |||
|                 return Verdict.FAIL, i | |||
|             if nextStateId in self.reachedStates: | |||
|                 return Verdict.PASS, i | |||
|             nextStatePair = (nextStateId, self.strategy[nextStateId]) | |||
|         return Verdict.INCONCLUSIVE, i | |||
| 
 | |||
|     def runTest(self): | |||
|         testCase = self._pickRandomTestCase() | |||
| 
 | |||
|         histogram = [0,0,0] | |||
|         for i in range(self.numSimulations): | |||
|             result, numQueries = self._simulate(testCase) | |||
|             if result == Verdict.FAIL: | |||
|                 return testCase, Verdict.FAIL, numQueries | |||
|             return testCase, Verdict.INCONCLUSIVE, numQueries | |||
| @ -0,0 +1,376 @@ | |||
| #!/usr/bin/python3 | |||
| 
 | |||
| import re, sys, os, shutil, fileinput, subprocess, argparse | |||
| from dataclasses import dataclass, field | |||
| import visvis as vv | |||
| import numpy as np | |||
| import argparse | |||
| 
 | |||
| from translate import translateTransitions, readLabels | |||
| from plotting import VisVisPlotter | |||
| from simulation import Simulator, Verdict | |||
| 
 | |||
| 
 | |||
| def convert(tuples): | |||
|     return dict(tuples) | |||
| 
 | |||
| def getBasename(filename): | |||
|     return os.path.basename(filename) | |||
| 
 | |||
| def traFileWithIteration(filename, iteration): | |||
|     return os.path.splitext(filename)[0] + f"_{iteration:03}.tra" | |||
| 
 | |||
| def copyFile(filename, newFilename): | |||
|     shutil.copy(filename, newFilename) | |||
| 
 | |||
| 
 | |||
| def execute(command, verbose=False): | |||
|     if verbose: print(f"Executing {command}") | |||
|     os.system(command) | |||
| 
 | |||
| @dataclass(frozen=True) | |||
| class State: | |||
|     id: int | |||
|     x: float | |||
|     x_vel: float | |||
|     y: float | |||
|     y_vel: float | |||
|     z: float | |||
|     z_vel: float | |||
| 
 | |||
| def default_value(): | |||
|     return {'action' : None, 'choiceValue' : None} | |||
| 
 | |||
| 
 | |||
| @dataclass(frozen=True) | |||
| class StateValue: | |||
|     ranking: float | |||
|     choices: dict = field(default_factory=default_value) | |||
| 
 | |||
| @dataclass(frozen=False) | |||
| class TestResult: | |||
|     prob_pes_min: float | |||
|     prob_pes_max: float | |||
|     prob_pes_avg: float | |||
|     prob_opt_min: float | |||
|     prob_opt_max: float | |||
|     prob_opt_avg: float | |||
|     min_min: float | |||
|     min_max: float | |||
| 
 | |||
|     def csv(self, ws=" "): | |||
|         return f"{self.prob_pes_min:0.04f}{ws}{self.prob_pes_max:0.04f}{ws}{self.prob_pes_avg:0.04f}{ws}{self.prob_opt_min:0.04f}{ws}{self.prob_opt_max:0.04f}{ws}{self.prob_opt_avg:0.04f}{ws}{self.min_min:0.04f}{ws}{self.min_max:0.04f}{ws}" | |||
| 
 | |||
| def parseStrategy(strategyFile, allStateActionPairs, time_index=0): | |||
|     strategy = dict() | |||
|     with open(strategyFile) as strategyLines: | |||
|         for line in strategyLines: | |||
|             line = line.replace("(","").replace(")","").replace("\n", "") | |||
|             explode = re.split(",|=", line) | |||
|             stateId = int(explode[0]) + 3 | |||
|             if stateId < 3: continue | |||
|             if int(explode[1]) != time_index: continue | |||
|             try: | |||
|                 strategy[stateId] = allStateActionPairs[stateId].index(explode[2]) | |||
|             except KeyError as e: | |||
|                 pass | |||
| 
 | |||
|     return strategy | |||
| 
 | |||
| def queryStrategy(strategy, stateId): | |||
|     try: | |||
|         return strategy[stateId] | |||
|     except: | |||
|         return -1 | |||
| 
 | |||
| 
 | |||
| def callTempest(files, reward, bound=3): | |||
|     property_str = "!(\"failed\" | \"reached\")" | |||
|     if True: | |||
|         prop =  f"filter(min, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" | |||
|         prop += f"filter(max, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" | |||
|         prop += f"filter(avg, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" | |||
|         prop += f"filter(min, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" | |||
|         prop += f"filter(max, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" | |||
|         prop += f"filter(avg, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" | |||
|         prop += f"filter(min, Rmin=? [ C<={bound} ], {property_str}  );" | |||
|         prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str}  );" | |||
|     else: | |||
|         prop =  f"filter(min, Rmin=? [ C<={bound} ], {property_str}  );" | |||
|         prop += f"filter(max, Rmin=? [ C<={bound} ], {property_str}  );" | |||
|         prop += f"filter(avg, Rmin=? [ C<={bound} ], {property_str}  );" | |||
|         prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str}  );" | |||
|         prop += f"filter(max, Rmax=? [ C<={bound} ], {property_str}  );" | |||
|         prop += f"filter(avg, Rmax=? [ C<={bound} ], {property_str}  );" | |||
|     command = f"~/projects/tempest-devel/ranking_release/bin/storm --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' " | |||
| 
 | |||
|     results = list() | |||
|     try: | |||
|         output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') | |||
|         for line in output: | |||
|             if "Result" in line and not len(results) >= 10: | |||
|                 range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) | |||
|                 if range_value: | |||
|                     results.append(float(range_value.group(2))) | |||
|                     results.append(float(range_value.group(3))) | |||
|                 else: | |||
|                     value = re.search(r"(.*:)(.*)", line) | |||
|                     results.append(float(value.group(2))) | |||
|     except subprocess.CalledProcessError as e: | |||
|         print(e.output) | |||
|     #results.append(-1) | |||
|     #results.append(-1) | |||
|     return TestResult(*(tuple(results))) | |||
| 
 | |||
| def parseRanking(filename, allStates): | |||
|     state_ranking = dict() | |||
|     try: | |||
|         with open(filename, "r") as f: | |||
|             filecontent = f.readlines() | |||
|         for line in filecontent: | |||
|             stateId = int(re.findall(r"^\d+", line)[0]) | |||
|             values = re.findall(r":(-?\d+\.?\d*),?", line) | |||
|             ranking_value = float(values[0]) | |||
|             choices = {i : float(value) for i,value in enumerate(values[1:])} | |||
|             state = allStates[stateId] | |||
|             value = StateValue(ranking_value, choices) | |||
|             state_ranking[state] = value | |||
|         if len(state_ranking) == 0: return | |||
|         all_values = [x.ranking for x in state_ranking.values()] | |||
|         max_value = max(all_values) | |||
|         min_value = min(all_values) | |||
|         new_state_ranking = {} | |||
|         for state, value in state_ranking.items(): | |||
|             choices = value.choices | |||
|             try: | |||
|                 new_value = (value.ranking - min_value) / (max_value - min_value) | |||
|             except ZeroDivisionError as e: | |||
|                 new_value = 0.0 | |||
|             new_state_ranking[state] = StateValue(new_value, choices) | |||
|         state_ranking = new_state_ranking | |||
|     except EnvironmentError: | |||
|         print("TODO file not available. Exiting.") | |||
|         sys.exit(1) | |||
|     return {state: values for state, values in sorted(state_ranking.items(), key=lambda item: item[1].ranking)} | |||
| 
 | |||
| def parseStateValuations(filename): | |||
|     all_states = dict() | |||
|     maxStateId = -1 | |||
|     for i in [0,1,2]: | |||
|         dummy_values = [i] * 7 | |||
|         all_states[i] = State(*dummy_values) | |||
|     with open(filename) as stateValuations: | |||
|         for line in stateValuations: | |||
|             values = re.findall(r"(-?\d+\.?\d*),?", line) | |||
|             values = [int(values[0])] + [float(v) for v in values[1:]] | |||
|             all_states[values[0]] = State(*values) | |||
|             if values[0] > maxStateId: maxStateId = values[0] | |||
|     dummy_values = [maxStateId + 1] * 7 | |||
|     all_states[maxStateId + 1] = State(*dummy_values) | |||
|     return all_states | |||
| 
 | |||
| def parseResults(allStates): | |||
|     state_to_values = dict() | |||
|     with open("prob_results_maximize") as maximizer, open("prob_results_minimize") as minimizer: | |||
|         for max_line, min_line in zip(maximizer, minimizer): | |||
|             max_values = re.findall(r"(-?\d+\.?\d*),?", max_line) | |||
|             min_values = re.findall(r"(-?\d+\.?\d*),?", min_line) | |||
|             if max_values[0] != min_values[0]: | |||
|                 print("min/max files do not match.") | |||
|                 assert(False) | |||
|             stateId = int(max_values[0]) | |||
|             min_result = float(min_values[1]) | |||
|             max_result = float(max_values[1]) | |||
|             value = (min_result, max_result, max_result - min_result) | |||
|             state_to_values[stateId] = value | |||
|     return state_to_values | |||
| 
 | |||
| 
 | |||
| def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration): | |||
|     stateIdRegex = re.compile(f"^{stateId}\s") | |||
|     for line in fileinput.input(filename, inplace = True): | |||
|         if not stateIdRegex.match(line): | |||
|             print(line, end="") | |||
|         else: | |||
|             explode = line.split(" ") | |||
|             if int(explode[1]) == chosenActionIndex: | |||
|                 print(line, end="") | |||
| 
 | |||
| def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration): | |||
|     stateIdsRegex = re.compile("|".join([f"^{stateId}\s" for stateId, actionIndex in stateActionPairsToTrim.items()])) | |||
|     for line in fileinput.input(filename, inplace = True): | |||
|         result = stateIdsRegex.match(line) | |||
|         if not result: | |||
|             print(line, end="") | |||
|         else: | |||
|             actionIndex = stateActionPairsToTrim[int(result[0])] | |||
|             explode = line.split(" ") | |||
|             if int(explode[1]) == actionIndex: | |||
|                 print(line, end="") | |||
| 
 | |||
| def getTopNStates(rankedStates, n, threshold): | |||
|     if n != 0: | |||
|         return dict(list(rankedStates.items())[-n:]) | |||
|     else: | |||
|         return {state:value for state,value in rankedStates.items() if value.ranking >= threshold} | |||
| 
 | |||
| def getNRandomStates(rankedStates, n, testedStates): | |||
|     stateIds = [state.id for state in rankedStates.keys()] | |||
|     notYetTestedStates = np.array([stateId for stateId in stateIds if stateId not in testedStates]) | |||
|     if len(notYetTestedStates) >= n: | |||
|         return notYetTestedStates[np.random.choice(len(notYetTestedStates), size=n, replace=False)] | |||
|     else: | |||
|         return notYetTestedStates | |||
| 
 | |||
| def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False): | |||
| 
 | |||
|     all_states = parseStateValuations("MDP_state_valuations") | |||
|     deadlockStates, reachedStates, maxStateId = readLabels(labFile) | |||
|     stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) | |||
|     strategy = parseStrategy(straFile, stateToActions) | |||
| 
 | |||
|     if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting) | |||
|     if plotting: plotter.plotScenario() | |||
| 
 | |||
|     copyFile("MDP_" + traFile, "MDP_" + os.path.splitext(getBasename(traFile))[0] + f"_000.tra") | |||
| 
 | |||
|     iteration = 0 | |||
|     #testsPerIteration = refinementSteps | |||
|     #refinementThreshold = | |||
|     numTestedStates = 0 | |||
|     totalIterations = 60 | |||
| 
 | |||
|     testedStates = list() | |||
|     while iteration < totalIterations: | |||
|         print(f"{iteration:03}", end="\t") | |||
|         sys.stdout.flush() | |||
|         currentTraFile = traFileWithIteration("MDP_" + traFile, iteration) | |||
|         nextTraFile = traFileWithIteration("MDP_" + traFile, iteration+1) | |||
|         testResult = callTempest(f"{currentTraFile} MDP_{labFile}",  "saferew", horizonBound) | |||
|         state_ranking = parseRanking("action_ranking", all_states) | |||
|         copyFile("action_ranking", f"action_ranking_{iteration:03}") | |||
|         copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}") | |||
|         copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}") | |||
| 
 | |||
|         if not ablationTesting: | |||
|             importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound) | |||
|             statesToTest = [state.id for state in importantStates.keys()] | |||
|             statesToPlot = importantStates | |||
|         else: | |||
|             statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates)) | |||
|             testedStates += statesToTest | |||
|             statesToPlot = {all_states[stateId]:StateValue(0,{}) for stateId in statesToTest} | |||
| 
 | |||
| 
 | |||
|         copyFile(currentTraFile, nextTraFile) | |||
|         stateActionPairsToTrim = dict() | |||
|         for testState in statesToTest: | |||
|             chosenActionIndex = queryStrategy(strategy, testState) | |||
|             if chosenActionIndex != -1: | |||
|                 stateActionPairsToTrim[testState] = chosenActionIndex | |||
|         stateEstimates = parseResults(all_states) | |||
|         results = [0,0,0] | |||
| 
 | |||
|         failureStates = list() | |||
|         validatedStates = list() | |||
|         for state, estimates in stateEstimates.items(): | |||
|             if state in deadlockStates or state in reachedStates: | |||
|                 continue | |||
|             if estimates[0] > 0.05: | |||
|                 results[1] += 1 | |||
|                 failureStates.append(all_states[state]) | |||
|                 #print(f"{state}: {estimates}") | |||
|             elif estimates[1] <= 0.05: | |||
|                 results[0] += 1 | |||
|                 validatedStates.append(all_states[state]) | |||
|             else: | |||
|                 results[2] += 1 | |||
| 
 | |||
| 
 | |||
|         removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration) | |||
|         print(f"{numTestedStates}\t{testResult.csv(' ')}\t{results[0]}\t{results[1]}\t{results[2]}\t{sum(results)}") | |||
|         if results[2] == 0: | |||
|             sys.exit(0) | |||
|         numTestedStates += len(statesToTest) | |||
|         iteration += 1 | |||
| 
 | |||
|         if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6), removeMeshes=True) | |||
|         if plotting: plotter.plotStates(validatedStates, coloring=(0.0,0.8,0.0,0.6)) | |||
|         if plotting: plotter.takeScreenshot(iteration, prefix="stepwise_0.05") | |||
| 
 | |||
| def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False): | |||
|     all_states = parseStateValuations("MDP_state_valuations") | |||
|     deadlockStates, reachedStates, maxStateId = readLabels(labFile) | |||
|     stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) | |||
|     strategy = parseStrategy(straFile, stateToActions) | |||
| 
 | |||
|     if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting) | |||
|     if plotting: plotter.plotScenario() | |||
|     passingStates = list() | |||
| 
 | |||
|     randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound) | |||
|     i = 0 | |||
|     print("Starting with random testing.") | |||
|     numQueries = 0 | |||
|     failureStates = list() | |||
|     while numQueries <= maxQueries: | |||
|         if i >= 500: | |||
|             if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6)) | |||
|             if plotting: plotter.takeScreenshot(iteration, prefix="random_testing") | |||
|             if plotting: plotter.turnCamera() | |||
|             if plotting: input("") | |||
|             print(f"{numQueries} {len(failureStates)} ") | |||
|             i = 0 | |||
|         testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest() | |||
|         i += queriesForThisTestCase | |||
|         numQueries += queriesForThisTestCase | |||
|         stateValuation = all_states[testCase] | |||
|         if testResult == Verdict.FAIL: | |||
|             failureStates.append(stateValuation) | |||
| 
 | |||
|     print(f"{numQueries} {len(failureStates)} ") | |||
| 
 | |||
| def parseArgs(): | |||
|     parser = argparse.ArgumentParser() | |||
|     parser.add_argument('--tra', type=str, required=True,  help='Path to .tra file.') | |||
|     parser.add_argument('--lab', type=str, required=True,  help='Path to .lab file.') | |||
|     parser.add_argument('--rew', type=str, required=True,  help='Path to .rew file.') | |||
|     parser.add_argument('--stra', type=str, required=True, help='Path to strategy file.') | |||
| 
 | |||
|     refinement = parser.add_mutually_exclusive_group(required=True) | |||
|     refinement.add_argument('--refinement-steps', type=int,   default=0, help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.') | |||
|     refinement.add_argument('--refinement-bound', type=float, default=0, help='Threshold value for states to be tested, mutually exclusive with refinement-steps.') | |||
| 
 | |||
|     parser.add_argument('--bound', type=int, required=False, default=3, help='(optional) Safety Horizon Bound, defaults to 3.') | |||
|     parser.add_argument('--threshold', type=float, required=False, default=0.05, help='(optional) Safety Threshold, defaults to 0.05.') | |||
| 
 | |||
|     random_testing = parser.add_mutually_exclusive_group() | |||
|     random_testing.add_argument('-a', '--ablation', action='store_true', help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.") | |||
|     random_testing.add_argument('-r', '--random', type=int, default=0, help='(optional) The amount of queries allowed for random testing.') | |||
| 
 | |||
|     parser.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.') | |||
|     parser.add_argument('--stepwise', action='store_true', help='(optional) Remove states before plotting the next iteration.') | |||
|     return parser.parse_args() | |||
| 
 | |||
| if __name__ == '__main__': | |||
|     args = parseArgs() | |||
| 
 | |||
|     traFile = args.tra | |||
|     labFile = args.lab | |||
|     straFile = args.stra | |||
|     rewFile = args.rew | |||
| 
 | |||
|     ablationTesting = args.ablation | |||
|     plotting = args.plotting | |||
|     stepwisePlotting = args.stepwise | |||
| 
 | |||
|     maxQueriesForRandomTesting = args.random | |||
|     horizonBound = args.bound | |||
| 
 | |||
|     refinementSteps = args.refinement_steps | |||
|     refinementBound = args.refinement_bound | |||
| 
 | |||
|     if maxQueriesForRandomTesting == 0: #akward way to test for this... | |||
|         main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting) | |||
|     else: | |||
|         randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting) | |||
| @ -0,0 +1,135 @@ | |||
| #!/usr/bin/python3 | |||
| 
 | |||
| import sys | |||
| #import re | |||
| import json | |||
| import numpy as np | |||
| from collections import deque | |||
| 
 | |||
| from dataclasses import dataclass | |||
| 
 | |||
| 
 | |||
| @dataclass(frozen=False) | |||
| class StateAction: | |||
|     state_id: int | |||
|     action_id: int | |||
|     next_state_probabilities: dict | |||
|     action_name: str | |||
| 
 | |||
|     def normalizeDistribution(self): | |||
|         weight = np.sum([value for key, value in self.next_state_probabilities.items()]) | |||
|         self.next_state_probabilities = { next_state_id : (probability / weight) for next_state_id, probability in self.next_state_probabilities.items() } | |||
| 
 | |||
| 
 | |||
| def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId): | |||
|     current_state_id = -1 | |||
|     current_action_id = -1 | |||
|     all_state_action_pairs = list() | |||
|     with open(traFile) as transitions: | |||
|         next(transitions) | |||
|         for line in transitions: | |||
|             line = line.replace("\n","") | |||
|             explode = line.split(" ") | |||
|             if len(explode) < 2: continue | |||
|             interval = json.loads(explode[3]) | |||
|             probability = (interval[0] + interval[1])/2 | |||
|             state_id = int(explode[0]) | |||
|             action_id = int(explode[1]) | |||
|             next_state_id = int(explode[2]) | |||
|             if len(explode) >= 5: | |||
|                 action_name = explode[4] | |||
|             else: | |||
|                 action_name = "" | |||
|             #print(f"State : {state_id} with action {action_id} leads with {probability} to {next_state_id}.") | |||
|             if state_id in [0, 1, 2]: | |||
|                 continue | |||
| 
 | |||
|             next_state_probabilities = {next_state_id: probability} | |||
|             if current_state_id != state_id: | |||
|                 new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name) | |||
|                 all_state_action_pairs.append(new_state_action_pair) | |||
|                 current_state_id = state_id | |||
|                 current_action_id = action_id | |||
| 
 | |||
|             elif current_action_id != action_id: | |||
|                 new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name) | |||
|                 all_state_action_pairs.append(new_state_action_pair) | |||
|                 current_action_id = action_id | |||
|             else: | |||
|                 all_state_action_pairs[-1].next_state_probabilities[next_state_id] = probability | |||
| 
 | |||
|     # we need to sort the deadlock and reached states to insert them while building the .tra file | |||
|     deadlockStates = [(state, 0) for state in deadlockStates] | |||
|     reachedStates = [(state, maxStateId) for state in reachedStates] | |||
|     final_states = deadlockStates + reachedStates | |||
|     final_states = deque(sorted(final_states, key=lambda tuple: tuple[0], reverse=True)) | |||
| 
 | |||
|     with open("MDP_" + traFile, "w") as new_transitions_file: | |||
|         new_transitions_file.write("mdp\n") | |||
|         new_transitions_file.write(f"0 0 {maxStateId} 1.0\n") | |||
|         for entry in all_state_action_pairs: | |||
|             entry.normalizeDistribution() | |||
|             source_state = int(entry.state_id) | |||
|             while final_states and int(final_states[-1][0]) < source_state: | |||
|                 final_state = final_states.pop() | |||
|                 if int(final_state[0]) == 0: continue | |||
|                 new_transitions_file.write(f"{final_state[0]} 0 {final_state[1]} 1.0\n") | |||
|             for next_state_id, probability in entry.next_state_probabilities.items(): | |||
|                 new_transitions_file.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n") | |||
|         new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n") | |||
| 
 | |||
|     state_to_actions = dict() | |||
|     for state_action_pair in all_state_action_pairs: | |||
|         if state_action_pair.state_id in state_to_actions: | |||
|             state_to_actions[state_action_pair.state_id].append(state_action_pair.action_name) | |||
|         else: | |||
|             state_to_actions[state_action_pair.state_id] = [state_action_pair.action_name] | |||
|     return state_to_actions, all_state_action_pairs | |||
| 
 | |||
| def readLabels(labFile): | |||
|     deadlockStates = list() | |||
|     reachedStates = list() | |||
|     with open(labFile) as states: | |||
|         newLabFile = "MDP_" + labFile | |||
|         newLabels = open(newLabFile, "w") | |||
|         optRewards = open(newLabFile + ".optrew", "w") | |||
|         safetyRewards = open(newLabFile + ".saferew", "w") | |||
|         labels = ["init", "deadlock", "reached", "failed"] | |||
|         next(states) | |||
|         newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n") | |||
|         maxStateId = -1 | |||
|         for line in states: | |||
|             line = line.replace(":","").replace("\n", "") | |||
|             explode = line.split(" ") | |||
|             newLabel = f"{explode[0]} " | |||
|             if int(explode[0]) > maxStateId: maxStateId = int(explode[0]) | |||
|             if int(explode[0]) == 0: | |||
|                 safetyRewards.write(f"{explode[0]} -100\n") | |||
|                 optRewards.write(f"{explode[0]} -100\n") | |||
|             #if "3" in explode: | |||
|             #    safetyRewards.write(f"{explode[0]} -100\n") | |||
|             #    optRewards.write(f"{explode[0]} -100\n") | |||
|             elif "2" in explode: | |||
|                 optRewards.write(f"{explode[0]} 100\n") | |||
|             else: | |||
|                 optRewards.write(f"{explode[0]} -1\n") | |||
|             for labelIndex in explode[1:]: | |||
|                 # sink states should not be deadlock states anymore: | |||
|                 if labelIndex == "1": | |||
|                     deadlockStates.append(int(explode[0])) | |||
|                     continue | |||
|                 if labelIndex == "2": | |||
|                     reachedStates.append(int(explode[0])) | |||
|                     continue | |||
|                 newLabel += f"{labels[int(labelIndex)]} " | |||
|             newLabels.write(newLabel + "\n") | |||
|         return deadlockStates, reachedStates, maxStateId + 1 | |||
| 
 | |||
| def main(traFile, labFile): | |||
|     deadlockStates, reachedStates, maxStateId = readLabels(labFile) | |||
|     translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) | |||
| 
 | |||
| if __name__ == '__main__': | |||
|     traFile = sys.argv[1] | |||
|     labFile = sys.argv[2] | |||
|     main(traFile, labFile) | |||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue