Browse Source

Updated formula computation to merge states into Y ranges

add_velocity_into_framework
sp 6 months ago
parent
commit
bfa808983b
  1. 103
      rom_evaluate.py

103
rom_evaluate.py

@ -170,22 +170,32 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
#saveObservations(all_obs, Verdict.GOOD, testDir)
return (Verdict.GOOD, num_queries)
def skiPositionFormulaList(name):
    """Build a disjunction over the per-ski-position labels of ``name``.

    For every ski position i from 1 to the module-level
    ``num_ski_positions`` (inclusive), a double-quoted label of the form
    name_i is generated. The labels are then combined by
    ``createBalancedDisjunction``.

    Args:
        name: Label prefix, e.g. 'Safe' or 'Unsafe'.

    Returns:
        Whatever ``createBalancedDisjunction`` returns for the list of
        quoted labels.
    """
    labels = [f"\"{name}_{i}\"" for i in range(1, num_ski_positions + 1)]
    return createBalancedDisjunction(labels)
def computeStateRanking(mdp_file, iteration):
logger.info("Computing state ranking")
tic()
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | {skiPositionFormulaList('Unsafe')}) ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | ({skiPositionFormulaList('Safe')} | {skiPositionFormulaList('Unsafe')}) );"
prop += 'Rmax=? [C <= 200]'
results = list()
try:
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
num_states = 0
for line in output:
#print(line)
if "States:" in line:
num_states = int(line.split(" ")[-1])
if "Result" in line and not len(results) >= 6:
range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
if range_value:
@ -199,7 +209,7 @@ def computeStateRanking(mdp_file, iteration):
# todo die gracefully if ranking is uniform
print(e.output)
logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
return TestResult(*tuple(results),0,0,0)
return TestResult(*tuple(results),0,0,0), num_states
def fillStateRanking(file_name, match=""):
logger.info(f"Parsing state ranking, {file_name}")
@ -214,13 +224,9 @@ def fillStateRanking(file_name, match=""):
if ranking_value <= 0.1:
continue
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
#print("stateMapping", stateMapping)
choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
choices = {key:float(value) for (key,value) in choices.items()}
#print("choices", choices)
#print("ranking_value", ranking_value)
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]), int(stateMapping["velocity"])//2)
#state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
value = StateValue(ranking_value, choices)
state_ranking[state] = value
logger.info(f"Parsing state ranking - DONE: took {toc()} seconds")
@ -235,8 +241,7 @@ def fillStateRanking(file_name, match=""):
def createDisjunction(formulas):
    """Join formula strings into a single flat disjunction.

    Args:
        formulas: Iterable of formula strings.

    Returns:
        The formulas separated by " | ". A single formula is returned
        unchanged, and an empty iterable yields the empty string.
    """
    return " | ".join(formulas)
def statesFormulaTrimmed(states):
if len(states) == 0: return "false"
def statesFormulaTrimmed(states, name):
#states = [(s[0].x,s[0].y, s[0].ski_position) for s in cluster]
skiPositionGroup = defaultdict(list)
for item in states:
@ -244,7 +249,7 @@ def statesFormulaTrimmed(states):
formulas = list()
for skiPosition, skiPos_group in skiPositionGroup.items():
formula = f"( ski_position={skiPosition} & "
formula = f"formula {name}_{skiPosition} = ( ski_position={skiPosition} & "
firstVelocity = True
velocityGroup = defaultdict(list)
for item in skiPos_group:
@ -254,18 +259,16 @@ def statesFormulaTrimmed(states):
firstVelocity = False
else:
formula += " | "
formulasPerSkiPosition = list()
formula += f" (velocity={velocity} & "
firstY = True
yPosGroup = defaultdict(list)
yAndXRanges = dict()
for item in velocity_group:
yPosGroup[item[1]].append(item)
for y, y_group in yPosGroup.items():
if firstY:
firstY = False
else:
formula += " | "
sorted_y_group = sorted(y_group, key=lambda s: s[0])
formula += f"( y={y} & ("
#formula += f"( y={y} & ("
current_x_min = sorted_y_group[0][0]
current_x = sorted_y_group[0][0]
x_ranges = list()
@ -277,14 +280,43 @@ def statesFormulaTrimmed(states):
current_x_min = state[0]
current_x = state[0]
x_ranges.append(f" ({current_x_min}<=x&x<={sorted_y_group[-1][0]})")
formula += " | ".join(x_ranges)
formula += ") )"
formula += ")"
xRangesDisjunction = createBalancedDisjunction(x_ranges)
if xRangesDisjunction in yAndXRanges:
yAndXRanges[xRangesDisjunction].append(y)
else:
yAndXRanges[xRangesDisjunction] = list()
yAndXRanges[xRangesDisjunction].append(y)
for xRange, ys in yAndXRanges.items():
#if firstY:
# firstY = False
#else:
# formula += " | "
sorted_ys = sorted(ys)
if len(ys) == 1:
formulasPerSkiPosition.append(f" ({xRange} & y={ys[0]})")
continue
current_y_min = sorted_ys[0]
current_y = sorted_ys[0]
y_ranges = list()
for y in sorted_ys[1:]:
if y - current_y == 2:
current_y = y
elif abs(y - current_y) > 2:
y_ranges.append(f" ({current_y_min}<=y&y<={current_y})")
current_y_min = y
current_y = y
y_ranges.append(f" ({current_y_min}<=y&y<={sorted_ys[-1]})")
formulasPerSkiPosition.append(f" ({xRange} & ({createBalancedDisjunction(y_ranges)}))")
formula += createBalancedDisjunction(formulasPerSkiPosition)
formula += ")"
formula += ");"
formulas.append(formula)
print(formulas)
sys.exit(1)
return createBalancedDisjunction(formulas)
for i in range(1, num_ski_positions+1):
if i in skiPositionGroup:
continue
formulas.append(f"formula {name}_{i} = false;")
return "\n".join(formulas) + "\n"
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
def pairwise(iterable):
@ -309,10 +341,11 @@ def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
newFile = f"{newFile}_{iteration:03}.prism"
exec(f"cp {initFile} {newFile}", verbose=False)
with open(newFile, "a") as prism:
prism.write(f"formula Safe = {statesFormulaTrimmed(safeStates)};\n")
prism.write(f"formula Unsafe = {statesFormulaTrimmed(unsafeStates)};\n")
prism.write(f"label \"Safe\" = Safe;\n")
prism.write(f"label \"Unsafe\" = Unsafe;\n")
prism.write(statesFormulaTrimmed(safeStates, "Safe"))
prism.write(statesFormulaTrimmed(unsafeStates, "Unsafe"))
for i in range(1,num_ski_positions+1):
prism.write(f"label \"Safe_{i}\" = Safe_{i};\n")
prism.write(f"label \"Unsafe_{i}\" = Unsafe_{i};\n")
logger.info(f"Creating next prism file - DONE: took {toc()} seconds")
@ -486,11 +519,11 @@ if __name__ == '__main__':
eps = 0.1
while True:
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
modelCheckingResult, numStates = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
if len(results) > 0:
modelCheckingResult.safeStates = results[-1].safeStates
modelCheckingResult.unsafeStates = results[-1].unsafeStates
modelCheckingResult.num_queries = results[-1].num_queries
modelCheckingResult.policy_queries = results[-1].policy_queries
results.append(modelCheckingResult)
logger.info(f"Model Checking Result: {modelCheckingResult}")
if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
@ -538,7 +571,13 @@ if __name__ == '__main__':
logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}")
results[-1].safeStates = len(safeStates)
results[-1].unsafeStates = len(unsafeStates)
results[-1].num_queries = num_queries
results[-1].policy_queries = num_queries
# Account for self-loop states after first iteration
if iteration > 0:
results[-1].init_check_pes_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_pes_avg*numStates + 1.0*results[-2].unsafeStates + 0.0*results[-2].safeStates)
results[-1].init_check_opt_avg = 1/(numStates+len(safeStates)+len(unsafeStates)) * (results[-1].init_check_opt_avg*numStates + 0.0*results[-2].unsafeStates + 1.0*results[-2].safeStates)
for result in results:
print(result.csv())
if testAll: drawClusters(failingPerCluster, f"failing", iteration)
drawResult(clusterResult, "result", iteration)

Loading…
Cancel
Save