You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
376 lines
16 KiB
376 lines
16 KiB
#!/usr/bin/python3
|
|
|
|
import re, sys, os, shutil, fileinput, subprocess, argparse
|
|
from dataclasses import dataclass, field
|
|
import visvis as vv
|
|
import numpy as np
|
|
import argparse
|
|
|
|
from translate import translateTransitions, readLabels
|
|
from plotting import VisVisPlotter
|
|
from simulation import Simulator, Verdict
|
|
|
|
|
|
def convert(tuples):
    """Return a new dict built from an iterable of key/value pairs (or copied from a mapping)."""
    mapping = {}
    # dict.update accepts exactly the same argument forms as the dict() constructor.
    mapping.update(tuples)
    return mapping
|
|
|
|
def getBasename(filename):
    """Return the last path component of `filename` (empty when it ends in a separator)."""
    _head, tail = os.path.split(filename)
    return tail
|
|
|
|
def traFileWithIteration(filename, iteration):
    """Swap the extension of `filename` for ``_<iteration>.tra``, zero-padding the iteration to 3 digits."""
    stem, _extension = os.path.splitext(filename)
    suffix = f"_{iteration:03}.tra"
    return stem + suffix
|
|
|
|
def copyFile(filename, newFilename):
    """Copy the file at `filename` to `newFilename` (content and permission bits, via shutil.copy)."""
    shutil.copy(src=filename, dst=newFilename)
    return None
|
|
|
|
|
|
def execute(command, verbose=False):
    """Run `command` through the system shell, optionally echoing it first.

    Args:
        command: shell command line (interpreted by the shell, so only pass trusted text).
        verbose: print "Executing <command>" before running it.

    Returns:
        The command's exit code (0 on success). Callers that ignore the return
        value behave exactly as before; others can now detect failures.
    """
    if verbose:
        print(f"Executing {command}")
    # subprocess.run reports a plain exit code, unlike os.system's encoded wait status.
    return subprocess.run(command, shell=True).returncode
|
|
|
|
@dataclass(eq=True, frozen=True)
class State:
    """Immutable valuation of one state: its id followed by six numeric coordinates.

    Being frozen (and compared by value) makes instances hashable, which lets them be
    used as dictionary keys, e.g. for the ranking dict built by parseRanking.
    """
    id: int        # state identifier (name shadows the builtin; kept for compatibility)
    x: float
    x_vel: float   # presumably velocity along x -- TODO confirm against the model
    y: float
    y_vel: float   # presumably velocity along y -- TODO confirm
    z: float
    z_vel: float   # presumably velocity along z -- TODO confirm
|
|
|
|
def default_value():
    """Return a fresh `choices` mapping with both 'action' and 'choiceValue' unset (None).

    Used as the default factory for StateValue.choices, so each call must build a new dict.
    """
    return dict(action=None, choiceValue=None)
|
|
|
|
|
|
@dataclass(frozen=True, eq=True)
class StateValue:
    """A state's ranking value together with its per-choice values.

    Attributes cannot be rebound (frozen), but `choices` itself is a regular dict.
    """
    ranking: float  # ranking score; parseRanking stores a min-max normalised value here
    # Per-choice values; when omitted, a fresh {'action': None, 'choiceValue': None} dict.
    choices: dict = field(default_factory=default_value)
|
|
|
|
@dataclass(frozen=False)
class TestResult:
    """The eight values reported by one Tempest run, in the order callTempest queries them."""
    # Pmax of eventually reaching "failed" within the bound: min / max / avg over filtered states.
    prob_pes_min: float
    prob_pes_max: float
    prob_pes_avg: float
    # Pmin of the same event: min / max / avg over filtered states.
    prob_opt_min: float
    prob_opt_max: float
    prob_opt_avg: float
    # Minimum over filtered states of Rmin and of Rmax of the bounded cumulative reward.
    min_min: float
    min_max: float

    def csv(self, ws=" "):
        """Format every field with four decimals, each followed by `ws` (the last one too)."""
        ordered = (
            self.prob_pes_min, self.prob_pes_max, self.prob_pes_avg,
            self.prob_opt_min, self.prob_opt_max, self.prob_opt_avg,
            self.min_min, self.min_max,
        )
        return "".join(f"{value:0.04f}{ws}" for value in ordered)
|
|
|
|
def parseStrategy(strategyFile, allStateActionPairs, time_index=0):
    """Read a strategy file and map each state id to the index of its chosen action.

    Each non-blank line is stripped of parentheses and split on ',' and '=';
    the fields are (state id, time index, action label). State ids are shifted by
    +3 (presumably to account for the three placeholder states 0-2 created by
    parseStateValuations -- confirm), and shifted ids below 3 are skipped.

    Args:
        strategyFile: path of the strategy file.
        allStateActionPairs: mapping from (shifted) state id to a sequence of action
            labels; the chosen label's position in it becomes the action index.
        time_index: only lines whose time field equals this value are used.

    Returns:
        dict {state id: action index}. States that are missing from
        `allStateActionPairs`, or whose action label is not in its sequence, are
        skipped (best effort).
    """
    strategy = dict()
    with open(strategyFile) as strategyLines:
        for line in strategyLines:
            line = line.replace("(", "").replace(")", "").replace("\n", "")
            if not line.strip():
                continue  # tolerate blank lines (previously crashed on int(""))
            explode = re.split(",|=", line)
            stateId = int(explode[0]) + 3
            if stateId < 3:
                continue
            if int(explode[1]) != time_index:
                continue
            try:
                strategy[stateId] = allStateActionPairs[stateId].index(explode[2])
            except (KeyError, IndexError, ValueError):
                # KeyError/IndexError: state unknown; ValueError: label not offered there.
                pass
    return strategy
|
|
|
|
def queryStrategy(strategy, stateId):
    """Return the action index the strategy assigns to `stateId`, or -1 if it assigns none.

    Only a missing entry maps to -1; unrelated errors (e.g. an unhashable id) are no
    longer hidden by a bare ``except``.
    """
    try:
        return strategy[stateId]
    except (KeyError, IndexError):
        return -1
|
|
|
|
|
|
def callTempest(files, reward, bound=3, storm_binary="~/projects/tempest-devel/ranking_release/bin/storm"):
    """Model-check the current model with Tempest (a Storm build) and wrap the results.

    Eight properties are checked, each aggregated by ``filter`` over the states
    satisfying ``!("failed" | "reached")``:
    Pmax of reaching "failed" within `bound` steps (min, max, avg), Pmin of the same
    event (min, max, avg), then the minimum of Rmin and of Rmax of the cumulative
    reward within `bound` steps.

    Args:
        files: whitespace-separated paths passed to ``--io:explicit``
            (callers pass "<tra file> <lab file>").
        reward: suffix selecting the state-reward file
            ``MDP_Abstraction_interval.lab.<reward>``.
        bound: step bound used in every property.
        storm_binary: Tempest/Storm executable; the command runs through the shell,
            so a leading ``~`` is expanded.

    Returns:
        TestResult built positionally from the parsed values. Each "Result" line
        contributes one value, or two for an interval ``[lo, hi]``; at most 10 values
        are collected, and a count other than 8 makes TestResult construction fail.

    Raises:
        subprocess.CalledProcessError: if the model checker exits non-zero (its
            stdout is printed first). Previously this fell through to a confusing
            TypeError from constructing TestResult without arguments.
    """
    property_str = "!(\"failed\" | \"reached\")"
    failure_formula = f"[ true U<={bound} \"failed\" ]"
    reward_formula = f"[ C<={bound} ]"
    # Order matters: the parsed values are passed positionally to TestResult.
    queries = [
        ("min", f"Pmax=? {failure_formula}"),
        ("max", f"Pmax=? {failure_formula}"),
        ("avg", f"Pmax=? {failure_formula}"),
        ("min", f"Pmin=? {failure_formula}"),
        ("max", f"Pmin=? {failure_formula}"),
        ("avg", f"Pmin=? {failure_formula}"),
        ("min", f"Rmin=? {reward_formula}"),
        ("min", f"Rmax=? {reward_formula}"),
    ]
    prop = "".join(f"filter({aggregate}, {formula}, {property_str} );" for aggregate, formula in queries)
    command = f"{storm_binary} --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' "

    try:
        output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
    except subprocess.CalledProcessError as e:
        print(e.output)
        raise

    results = list()
    for line in output:
        if "Result" not in line or len(results) >= 10:
            continue
        range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
        if range_value:
            # Interval result: record both bounds.
            results.append(float(range_value.group(2)))
            results.append(float(range_value.group(3)))
        else:
            # Scalar result: the text after the last ':'.
            value = re.search(r"(.*:)(.*)", line)
            results.append(float(value.group(2)))
    return TestResult(*results)
|
|
|
|
def parseRanking(filename, allStates):
    """Read Tempest's ranking file and return min-max normalised rankings per state.

    Each non-blank line must start with a state id; every ``:<number>`` occurrence
    on it is a value. The first value is the state's ranking, and the remaining
    values become its `choices`, keyed by position (0, 1, ...).

    Args:
        filename: path of the ranking file.
        allStates: mapping from state id to State.

    Returns:
        dict {State: StateValue} with rankings rescaled to [0, 1] (all 0.0 when
        every ranking is equal), ordered by ascending ranking. An empty dict is
        returned when the file holds no rankings (previously None, which broke
        callers that call .items()/.keys() on the result).

    Exits the process with status 1 if the file cannot be read.
    """
    try:
        with open(filename, "r") as f:
            filecontent = f.readlines()
    except EnvironmentError:
        print(f"{filename} not available. Exiting.")
        sys.exit(1)

    state_ranking = dict()
    for line in filecontent:
        if not line.strip():
            continue  # tolerate blank lines
        stateId = int(re.findall(r"^\d+", line)[0])
        values = re.findall(r":(-?\d+\.?\d*),?", line)
        ranking_value = float(values[0])
        choices = {i: float(value) for i, value in enumerate(values[1:])}
        state_ranking[allStates[stateId]] = StateValue(ranking_value, choices)

    if not state_ranking:
        return {}

    all_values = [x.ranking for x in state_ranking.values()]
    min_value = min(all_values)
    span = max(all_values) - min_value
    normalised = {}
    for state, value in state_ranking.items():
        # A zero span means all rankings are equal; map them all to 0.0.
        new_value = (value.ranking - min_value) / span if span != 0 else 0.0
        normalised[state] = StateValue(new_value, value.choices)

    return dict(sorted(normalised.items(), key=lambda item: item[1].ranking))
|
|
|
|
def parseStateValuations(filename):
    """Load state valuations into a dict {state id: State}.

    Ids 0, 1 and 2 get placeholder States whose seven fields all equal the id, and one
    more placeholder is added at (highest parsed id + 1). Every line of the file is
    parsed as numbers: the first is the id (int), the rest become float fields, so
    each line must carry exactly seven numbers.
    """
    all_states = {dummyId: State(*([dummyId] * 7)) for dummyId in range(3)}
    highestId = -1
    with open(filename) as stateValuations:
        for line in stateValuations:
            numbers = re.findall(r"(-?\d+\.?\d*),?", line)
            stateId = int(numbers[0])
            fields = [stateId] + list(map(float, numbers[1:]))
            all_states[stateId] = State(*fields)
            highestId = max(highestId, stateId)
    sentinelId = highestId + 1
    all_states[sentinelId] = State(*([sentinelId] * 7))
    return all_states
|
|
|
|
def parseResults(allStates, maxFile="prob_results_maximize", minFile="prob_results_minimize"):
    """Pair up the per-state values of the maximising and minimising result files.

    Both files hold one state per line; the first number on a line is the state id
    and the second its value. Lines are matched by position and must refer to the
    same state id. Blank lines are ignored.

    Args:
        allStates: unused; kept for interface compatibility.
        maxFile: path of the maximising results file.
        minFile: path of the minimising results file.

    Returns:
        dict {state id: (min value, max value, max - min)}.

    Raises:
        ValueError: if the files differ in length or a line pair has different state
            ids. (Previously an ``assert`` that is stripped under ``python -O``, and a
            length mismatch was silently truncated by ``zip``.)
    """
    with open(maxFile) as maximizer, open(minFile) as minimizer:
        max_lines = [line for line in maximizer if line.strip()]
        min_lines = [line for line in minimizer if line.strip()]
    if len(max_lines) != len(min_lines):
        raise ValueError("min/max files do not match.")

    state_to_values = dict()
    for max_line, min_line in zip(max_lines, min_lines):
        max_values = re.findall(r"(-?\d+\.?\d*),?", max_line)
        min_values = re.findall(r"(-?\d+\.?\d*),?", min_line)
        if max_values[0] != min_values[0]:
            raise ValueError("min/max files do not match.")
        stateId = int(max_values[0])
        min_result = float(min_values[1])
        max_result = float(max_values[1])
        state_to_values[stateId] = (min_result, max_result, max_result - min_result)
    return state_to_values
|
|
|
|
|
|
def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration):
    """Rewrite `filename` in place so `stateId` keeps only rows of `chosenActionIndex`.

    A line belongs to `stateId` when it starts with that id followed by whitespace;
    its second whitespace-separated field is compared to `chosenActionIndex`. All
    other lines are copied unchanged.

    Args:
        stateId: id whose rows are filtered.
        chosenActionIndex: action index to keep for that state.
        filename: transition file, modified in place.
        iteration: unused; kept for interface compatibility.
    """
    # Raw f-string: "\s" is a regex escape, not a (deprecated) string escape.
    stateIdRegex = re.compile(rf"^{stateId}\s")
    # The context manager restores stdout and closes the file even if a line fails to parse.
    with fileinput.input(filename, inplace=True) as transitions:
        for line in transitions:
            if not stateIdRegex.match(line):
                print(line, end="")
                continue
            # split() handles tabs / repeated spaces consistently with the \s match above.
            if int(line.split()[1]) == chosenActionIndex:
                print(line, end="")
|
|
|
|
def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration):
    """Rewrite `filename` in place, keeping only the given action's rows for each listed state.

    For every line that starts with a state id (followed by whitespace) present in
    `stateActionPairsToTrim`, the line is kept only if its second whitespace-separated
    field equals that state's action index. All other lines are copied unchanged.

    Args:
        stateActionPairsToTrim: dict {state id: action index to keep}. When it is
            empty the file is left untouched. (Previously the empty alternation
            regex matched every line and crashed on int("").)
        filename: transition file, modified in place.
        iteration: unused; kept for interface compatibility.
    """
    if not stateActionPairsToTrim:
        return
    # Compare the leading id as text, exactly as the per-state "^{id}\s" patterns did,
    # but with one dict lookup per line instead of an alternation over every state.
    actionByStateId = {f"{stateId}": actionIndex for stateId, actionIndex in stateActionPairsToTrim.items()}
    leadingId = re.compile(r"^(-?\d+)\s")
    with fileinput.input(filename, inplace=True) as transitions:
        for line in transitions:
            match = leadingId.match(line)
            if match is None or match.group(1) not in actionByStateId:
                print(line, end="")
                continue
            if int(line.split()[1]) == actionByStateId[match.group(1)]:
                print(line, end="")
|
|
|
|
def getTopNStates(rankedStates, n, threshold):
    """Select the states to test next from a ranking dict.

    With ``n == 0`` every entry whose ``ranking`` is at least `threshold` is kept.
    Otherwise the last `n` entries are returned; for a dict ordered by ascending
    ranking (as parseRanking produces) these are the highest-ranked states.
    """
    if n == 0:
        passing = filter(lambda entry: entry[1].ranking >= threshold, rankedStates.items())
        return dict(passing)
    entries = list(rankedStates.items())
    return dict(entries[-n:])
|
|
|
|
def getNRandomStates(rankedStates, n, testedStates):
    """Pick up to `n` random, not-yet-tested state ids from the keys of `rankedStates`.

    Args:
        rankedStates: dict whose keys have an ``id`` attribute (State objects).
        n: number of ids to draw without replacement.
        testedStates: ids already tested (any iterable of hashable ids).

    Returns:
        numpy array of ids. When fewer than `n` untested ids remain, all of them
        are returned in key order. Draws use the global ``np.random`` state.
    """
    # A set gives O(1) membership instead of rescanning the growing list for every state.
    alreadyTested = set(testedStates)
    notYetTestedStates = np.array([state.id for state in rankedStates.keys() if state.id not in alreadyTested])
    if len(notYetTestedStates) >= n:
        return notYetTestedStates[np.random.choice(len(notYetTestedStates), size=n, replace=False)]
    return notYetTestedStates
|
|
|
|
def _strategyActionsFor(strategy, stateIds):
    """Map each state id to the strategy's action index, skipping states without a choice."""
    chosen = dict()
    for stateId in stateIds:
        chosenActionIndex = queryStrategy(strategy, stateId)
        if chosenActionIndex != -1:
            chosen[stateId] = chosenActionIndex
    return chosen


def _classifyStates(stateEstimates, all_states, deadlockStates, reachedStates, threshold):
    """Count validated / failing / undecided states and collect the first two as States.

    ``estimates[0]`` / ``estimates[1]`` are the min / max values from parseResults
    (presumably failure probabilities -- confirm). A state fails when even the min
    exceeds `threshold`, is validated when the max is at most `threshold`, and is
    undecided otherwise. Deadlock and reached states are skipped.

    Returns:
        ([validated, failing, undecided] counts, failureStates, validatedStates)
    """
    counts = [0, 0, 0]
    failureStates = list()
    validatedStates = list()
    for stateId, estimates in stateEstimates.items():
        if stateId in deadlockStates or stateId in reachedStates:
            continue
        if estimates[0] > threshold:
            counts[1] += 1
            failureStates.append(all_states[stateId])
        elif estimates[1] <= threshold:
            counts[0] += 1
            validatedStates.append(all_states[stateId])
        else:
            counts[2] += 1
    return counts, failureStates, validatedStates


def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False, safetyThreshold=0.05, totalIterations=60):
    """Iteratively refine the model by pinning tested states to the strategy's action.

    Each iteration:
    1. Model-checks the current ``MDP_<tra>_<iter>.tra`` with Tempest.
    2. Ranks the states and snapshots the ranking/result files.
    3. Selects states to test: top-ranked ones, or random untested ones when
       `ablationTesting` is set.
    4. Writes the next iteration's transition file with those states restricted
       to the strategy's action.
    5. Prints one tab-separated statistics line.

    Returns as soon as no state is left undecided; previously this called
    ``sys.exit(0)``, which gives the same exit status when run as a script.

    Args:
        traFile, labFile, straFile: transition, label and strategy file names; the
            "MDP_"-prefixed copies of traFile/labFile are read from the working directory.
        horizonBound: step bound for the Tempest properties.
        refinementSteps: states tested per iteration (0 selects by `refinementBound`).
        refinementBound: ranking threshold used by getTopNStates when `refinementSteps` is 0.
        ablationTesting: choose untested states at random instead of by ranking.
        plotting, stepwisePlotting: enable VisVis plotting / per-iteration mesh removal.
        safetyThreshold: bound separating validated from failing states (default 0.05).
        totalIterations: maximum number of refinement iterations (default 60).
    """
    all_states = parseStateValuations("MDP_state_valuations")
    deadlockStates, reachedStates, maxStateId = readLabels(labFile)
    stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
    strategy = parseStrategy(straFile, stateToActions)

    if plotting:
        plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
        plotter.plotScenario()

    # Seed iteration 0 under exactly the name the loop reads. Building it from the
    # basename diverged from that name whenever traFile contained a directory.
    copyFile("MDP_" + traFile, traFileWithIteration("MDP_" + traFile, 0))

    numTestedStates = 0
    testedStates = list()
    for iteration in range(totalIterations):
        print(f"{iteration:03}", end="\t")
        sys.stdout.flush()
        currentTraFile = traFileWithIteration("MDP_" + traFile, iteration)
        nextTraFile = traFileWithIteration("MDP_" + traFile, iteration + 1)
        testResult = callTempest(f"{currentTraFile} MDP_{labFile}", "saferew", horizonBound)
        state_ranking = parseRanking("action_ranking", all_states)
        # Keep per-iteration snapshots of the files Tempest overwrites on each run.
        copyFile("action_ranking", f"action_ranking_{iteration:03}")
        copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}")
        copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}")

        if ablationTesting:
            statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates))
            testedStates += statesToTest
        else:
            importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound)
            statesToTest = [state.id for state in importantStates.keys()]

        copyFile(currentTraFile, nextTraFile)
        stateActionPairsToTrim = _strategyActionsFor(strategy, statesToTest)
        results, failureStates, validatedStates = _classifyStates(
            parseResults(all_states), all_states, deadlockStates, reachedStates, safetyThreshold)

        removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration)
        print(f"{numTestedStates}\t{testResult.csv(' ')}\t{results[0]}\t{results[1]}\t{results[2]}\t{sum(results)}")
        if results[2] == 0:
            # Every remaining state is validated or failing: nothing left to refine.
            return
        numTestedStates += len(statesToTest)

        if plotting:
            plotter.plotStates(failureStates, coloring=(0.8, 0.0, 0.0, 0.6), removeMeshes=True)
            plotter.plotStates(validatedStates, coloring=(0.0, 0.8, 0.0, 0.6))
            # Screenshots are numbered from 1 (the count of completed iterations).
            plotter.takeScreenshot(iteration + 1, prefix=f"stepwise_{safetyThreshold}")
|
|
|
|
def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False, stepwisePlotting=None):
    """Baseline: run random simulated test cases until the query budget is spent.

    Prints "<queries used> <failures found> " after roughly every 500 queries and
    once at the end.

    Args:
        traFile, labFile, straFile: transition, label and strategy file names.
        bound: horizon passed to the Simulator.
        maxQueries: query budget; no new test case is started once it is reached.
            (The former ``<=`` test started one extra test case at the limit.)
        plotting: enable VisVis plotting; each report then waits for input().
        stepwisePlotting: forwarded to VisVisPlotter. When None, the module-level
            ``stepwisePlotting`` set by the script entry point is used if it exists,
            else False. (Previously this was an undeclared global, so calling the
            function after importing the module raised NameError.)
    """
    if stepwisePlotting is None:
        stepwisePlotting = globals().get("stepwisePlotting", False)

    all_states = parseStateValuations("MDP_state_valuations")
    deadlockStates, reachedStates, maxStateId = readLabels(labFile)
    stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
    strategy = parseStrategy(straFile, stateToActions)

    if plotting:
        plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
        plotter.plotScenario()

    randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound)
    print("Starting with random testing.")
    queriesSinceReport = 0
    numQueries = 0
    # Screenshot numbering; the old code passed an undefined name `iteration` here.
    screenshotIndex = 0
    failureStates = list()
    while numQueries < maxQueries:
        if queriesSinceReport >= 500:
            if plotting:
                plotter.plotStates(failureStates, coloring=(0.8, 0.0, 0.0, 0.6))
                plotter.takeScreenshot(screenshotIndex, prefix="random_testing")
                plotter.turnCamera()
                input("")
                screenshotIndex += 1
            print(f"{numQueries} {len(failureStates)} ")
            queriesSinceReport = 0
        testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest()
        queriesSinceReport += queriesForThisTestCase
        numQueries += queriesForThisTestCase
        stateValuation = all_states[testCase]
        if testResult == Verdict.FAIL:
            failureStates.append(stateValuation)

    print(f"{numQueries} {len(failureStates)} ")
|
|
|
|
def parseArgs():
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()

    # Mandatory model and strategy inputs.
    for flag, description in (
        ('--tra', 'Path to .tra file.'),
        ('--lab', 'Path to .lab file.'),
        ('--rew', 'Path to .rew file.'),
        ('--stra', 'Path to strategy file.'),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)

    # Exactly one way of selecting the states to refine must be given.
    selection = parser.add_mutually_exclusive_group(required=True)
    selection.add_argument('--refinement-steps', type=int, default=0,
                           help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.')
    selection.add_argument('--refinement-bound', type=float, default=0,
                           help='Threshold value for states to be tested, mutually exclusive with refinement-steps.')

    parser.add_argument('--bound', type=int, required=False, default=3,
                        help='(optional) Safety Horizon Bound, defaults to 3.')
    parser.add_argument('--threshold', type=float, required=False, default=0.05,
                        help='(optional) Safety Threshold, defaults to 0.05.')

    # Ablation testing and plain random testing cannot be combined.
    testingMode = parser.add_mutually_exclusive_group()
    testingMode.add_argument('-a', '--ablation', action='store_true',
                             help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.")
    testingMode.add_argument('-r', '--random', type=int, default=0,
                             help='(optional) The amount of queries allowed for random testing.')

    parser.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.')
    parser.add_argument('--stepwise', action='store_true',
                        help='(optional) Remove states before plotting the next iteration.')

    parsed = parser.parse_args()
    return parsed
|
|
|
|
if __name__ == '__main__':
    # Expose every option as a module-level name: randomTesting() reads the global
    # `stepwisePlotting`, so these must stay module globals rather than locals.
    args = parseArgs()

    traFile, labFile, straFile, rewFile = args.tra, args.lab, args.stra, args.rew
    ablationTesting, plotting, stepwisePlotting = args.ablation, args.plotting, args.stepwise
    maxQueriesForRandomTesting, horizonBound = args.random, args.bound
    refinementSteps, refinementBound = args.refinement_steps, args.refinement_bound
    # NOTE(review): rewFile and args.threshold are parsed but never used below.

    if maxQueriesForRandomTesting != 0:
        # A non-zero query budget (-r/--random) selects plain random testing.
        randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting)
    else:
        main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting)
|