|
|
@ -170,22 +170,32 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50): |
|
|
|
#saveObservations(all_obs, Verdict.GOOD, testDir) |
|
|
|
return (Verdict.GOOD, num_queries) |
|
|
|
|
|
|
|
def skiPositionFormulaList(name): |
|
|
|
formulas = list() |
|
|
|
for i in range(1, num_ski_positions+1): |
|
|
|
formulas.append(f"\"{name}_{i}\"") |
|
|
|
return createBalancedDisjunction(formulas) |
|
|
|
|
|
|
|
|
|
|
|
def computeStateRanking(mdp_file, iteration): |
|
|
|
logger.info("Computing state ranking") |
|
|
|
tic() |
|
|
|
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" |
|
|
|
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" |
|
|
|
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" |
|
|
|
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" |
|
|
|
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" |
|
|
|
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );" |
|
|
|
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|
|
|
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|
|
|
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|
|
|
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|
|
|
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|
|
|
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );" |
|
|
|
prop += 'Rmax=? [C <= 200]' |
|
|
|
results = list() |
|
|
|
try: |
|
|
|
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'" |
|
|
|
output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') |
|
|
|
num_states = 0 |
|
|
|
for line in output: |
|
|
|
#print(line) |
|
|
|
if "States:" in line: |
|
|
|
num_states = int(line.split(" ")[-1]) |
|
|
|
if "Result" in line and not len(results) >= 6: |
|
|
|
range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) |
|
|
|
if range_value: |
|
|
@ -199,7 +209,7 @@ def computeStateRanking(mdp_file, iteration): |
|
|
|
# todo die gracefully if ranking is uniform |
|
|
|
print(e.output) |
|
|
|
logger.info(f"Computing state ranking - DONE: took {toc()} seconds") |
|
|
|
return TestResult(*tuple(results),0,0,0) |
|
|
|
return TestResult(*tuple(results),0,0,0), num_states |
|
|
|
|
|
|
|
def fillStateRanking(file_name, match=""): |
|
|
|
logger.info(f"Parsing state ranking, {file_name}") |
|
|
@ -214,13 +224,9 @@ def fillStateRanking(file_name, match=""): |
|
|
|
if ranking_value <= 0.1: |
|
|
|
continue |
|
|
|
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) |
|
|
|
#print("stateMapping", stateMapping) |
|
|
|
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line)) |
|
|
|
choices = {key:float(value) for (key,value) in choices.items()} |
|
|
|
#print("choices", choices) |
|
|
|
#print("ranking_value", ranking_value) |
|
|
|
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2) |
|
|
|
#state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) |
|
|
|
value = StateValue(ranking_value, choices) |
|
|
|
state_ranking[state] = value |
|
|
|
logger.info(f"Parsing state ranking - DONE: took {toc()} seconds") |
|
|
@ -235,8 +241,7 @@ def fillStateRanking(file_name, match=""): |
|
|
|
def createDisjunction(formulas): |
|
|
|
return " | ".join(formulas) |
|
|
|
|
|
|
|
def statesFormulaTrimmed(states): |
|
|
|
if len(states) == 0: return "false" |
|
|
|
def statesFormulaTrimmed(states, name): |
|
|
|
#states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster] |
|
|
|
skiPositionGroup = defaultdict(list) |
|
|
|
for item in states: |
|
|
@ -244,7 +249,7 @@ def statesFormulaTrimmed(states): |
|
|
|
|
|
|
|
formulas = list() |
|
|
|
for skiPosition, skiPos_group in skiPositionGroup.items(): |
|
|
|
formula = f"( ski_position={skiPosition} & " |
|
|
|
formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & " |
|
|
|
firstVelocity = True |
|
|
|
velocityGroup = defaultdict(list) |
|
|
|
for item in skiPos_group: |
|
|
@ -254,18 +259,16 @@ def statesFormulaTrimmed(states): |
|
|
|
firstVelocity = False |
|
|
|
else: |
|
|
|
formula += " | " |
|
|
|
formulasPerSkiPosition = list() |
|
|
|
formula += f" (velocity={velocity} & " |
|
|
|
firstY = True |
|
|
|
yPosGroup = defaultdict(list) |
|
|
|
yAndXRanges = dict() |
|
|
|
for item in velocity_group: |
|
|
|
yPosGroup[item[1]].append(item) |
|
|
|
for y, y_group in yPosGroup.items(): |
|
|
|
if firstY: |
|
|
|
firstY = False |
|
|
|
else: |
|
|
|
formula += " | " |
|
|
|
sorted_y_group = sorted(y_group, key=lambda s: s[0]) |
|
|
|
formula += f"( y={y} & (" |
|
|
|
#formula += f"( y={y} & (" |
|
|
|
current_x_min = sorted_y_group[0][0] |
|
|
|
current_x = sorted_y_group[0][0] |
|
|
|
x_ranges = list() |
|
|
@ -273,18 +276,47 @@ def statesFormulaTrimmed(states): |
|
|
|
if state[0] - current_x == 1: |
|
|
|
current_x = state[0] |
|
|
|
else: |
|
|
|
x_ranges.append(f" ({current_x_min}<= x & x<={current_x})") |
|
|
|
x_ranges.append(f" ({current_x_min}<=x&x<={current_x})") |
|
|
|
current_x_min = state[0] |
|
|
|
current_x = state[0] |
|
|
|
x_ranges.append(f" ({current_x_min}<= x & x<={sorted_y_group[-1][0]})") |
|
|
|
formula += " | ".join(x_ranges) |
|
|
|
formula += ") )" |
|
|
|
x_ranges.append(f" ({current_x_min}<=x&x<={sorted_y_group[-1][0]})") |
|
|
|
xRangesDisjunction = createBalancedDisjunction(x_ranges) |
|
|
|
if xRangesDisjunction in yAndXRanges: |
|
|
|
yAndXRanges[xRangesDisjunction].append(y) |
|
|
|
else: |
|
|
|
yAndXRanges[xRangesDisjunction] = list() |
|
|
|
yAndXRanges[xRangesDisjunction].append(y) |
|
|
|
for xRange, ys in yAndXRanges.items(): |
|
|
|
#if firstY: |
|
|
|
# firstY = False |
|
|
|
#else: |
|
|
|
# formula += " | " |
|
|
|
sorted_ys = sorted(ys) |
|
|
|
if len(ys) == 1: |
|
|
|
formulasPerSkiPosition.append(f" ({xRange} & y={ys[0]})") |
|
|
|
continue |
|
|
|
current_y_min = sorted_ys[0] |
|
|
|
current_y = sorted_ys[0] |
|
|
|
y_ranges = list() |
|
|
|
for y in sorted_ys[1:]: |
|
|
|
if y - current_y == 2: |
|
|
|
current_y = y |
|
|
|
elif abs(y - current_y) > 2: |
|
|
|
y_ranges.append(f" ({current_y_min}<=y&y<={current_y})") |
|
|
|
current_y_min = y |
|
|
|
current_y = y |
|
|
|
y_ranges.append(f" ({current_y_min}<=y&y<={sorted_ys[-1]})") |
|
|
|
formulasPerSkiPosition.append(f" ({xRange} & ({createBalancedDisjunction(y_ranges)}))") |
|
|
|
|
|
|
|
formula += createBalancedDisjunction(formulasPerSkiPosition) |
|
|
|
formula += ")" |
|
|
|
formula += ")" |
|
|
|
formula += ");" |
|
|
|
formulas.append(formula) |
|
|
|
print(formulas) |
|
|
|
sys.exit(1) |
|
|
|
return createBalancedDisjunction(formulas) |
|
|
|
for i in range(1, num_ski_positions+1): |
|
|
|
if i in skiPositionGroup: |
|
|
|
continue |
|
|
|
formulas.append(f"formula {name}_{i} = false;") |
|
|
|
return "\n".join(formulas) + "\n" |
|
|
|
|
|
|
|
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list |
|
|
|
def pairwise(iterable): |
|
|
@ -309,10 +341,11 @@ def updatePrismFile(newFile, iteration, safeStates, unsafeStates): |
|
|
|
newFile = f"{newFile}_{iteration:03}.prism" |
|
|
|
exec(f"cp {initFile} {newFile}", verbose=False) |
|
|
|
with open(newFile, "a") as prism: |
|
|
|
prism.write(f"formula Safe = {statesFormulaTrimmed(safeStates)};\n") |
|
|
|
prism.write(f"formula Unsafe = {statesFormulaTrimmed(unsafeStates)};\n") |
|
|
|
prism.write(f"label \"Safe\" = Safe;\n") |
|
|
|
prism.write(f"label \"Unsafe\" = Unsafe;\n") |
|
|
|
prism.write(statesFormulaTrimmed(safeStates, "Safe")) |
|
|
|
prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe")) |
|
|
|
for i in range(1,num_ski_positions+1): |
|
|
|
prism.write(f"label \"Safe_{i}\" = Safe_{i};\n") |
|
|
|
prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n") |
|
|
|
|
|
|
|
logger.info(f"Creating next prism file - DONE: took {toc()} seconds") |
|
|
|
|
|
|
@ -486,11 +519,11 @@ if __name__ == '__main__': |
|
|
|
eps = 0.1 |
|
|
|
while True: |
|
|
|
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) |
|
|
|
modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) |
|
|
|
modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) |
|
|
|
if len(results) > 0: |
|
|
|
modelCheckingResult.safeStates = results[-1].safeStates |
|
|
|
modelCheckingResult.unsafeStates = results[-1].unsafeStates |
|
|
|
modelCheckingResult.num_queries = results[-1].num_queries |
|
|
|
modelCheckingResult.safeStates = results[-1].safeStates |
|
|
|
modelCheckingResult.unsafeStates = results[-1].unsafeStates |
|
|
|
modelCheckingResult.policy_queries = results[-1].policy_queries |
|
|
|
results.append(modelCheckingResult) |
|
|
|
logger.info(f"Model Checking Result: {modelCheckingResult}") |
|
|
|
if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps: |
|
|
@ -538,7 +571,13 @@ if __name__ == '__main__': |
|
|
|
logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}") |
|
|
|
results[-1].safeStates = len(safeStates) |
|
|
|
results[-1].unsafeStates = len(unsafeStates) |
|
|
|
results[-1].num_queries = num_queries |
|
|
|
results[-1].policy_queries = num_queries |
|
|
|
# Account for self-loop states after first iteration |
|
|
|
if iteration > 0: |
|
|
|
results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafeStates + 0.0*results[-2].safeStates) |
|
|
|
results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafeStates + 1.0*results[-2].safeStates) |
|
|
|
for result in results: |
|
|
|
print(result.csv()) |
|
|
|
if testAll: drawClusters(failingPerCluster, f"failing", iteration) |
|
|
|
|
|
|
|
drawResult(clusterResult, "result", iteration) |
|
|
|