Browse Source

more changes...

add_velocity_into_framework
sp 8 months ago
parent
commit
6fce864656
  1. 49
      rom_evaluate.py

49
rom_evaluate.py

@ -77,6 +77,9 @@ class TestResult:
init_check_opt_min: float init_check_opt_min: float
init_check_opt_max: float init_check_opt_max: float
init_check_opt_avg: float init_check_opt_avg: float
safe_states: int
unsafe_states: int
policy_queries: int
def __str__(self): def __str__(self):
return f"""Test Result: return f"""Test Result:
init_check_pes_min: {self.init_check_pes_min} init_check_pes_min: {self.init_check_pes_min}
@ -87,7 +90,7 @@ class TestResult:
init_check_opt_avg: {self.init_check_opt_avg} init_check_opt_avg: {self.init_check_opt_avg}
""" """
def csv(self, ws=" "): def csv(self, ws=" "):
return f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}"
return f"{self.init_check_pes_min:0.04f}{ws}{self.init_check_pes_max:0.04f}{ws}{self.init_check_pes_avg:0.04f}{ws}{self.init_check_opt_min:0.04f}{ws}{self.init_check_opt_max:0.04f}{ws}{self.init_check_opt_avg:0.04f}{ws}{self.safeStates}{ws}{self.unsafeStates}{ws}{self.policy_queries}"
def exec(command,verbose=True): def exec(command,verbose=True):
if verbose: print(f"Executing {command}") if verbose: print(f"Executing {command}")
@ -128,10 +131,9 @@ def saveObservations(observations, verdict, testDir):
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) } ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50): def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
#def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=50):
#print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="") #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
testDir = f"{x}_{y}_{ski_position}_{velocity}" testDir = f"{x}_{y}_{ski_position}_{velocity}"
#testDir = f"{x}_{y}_{ski_position}"
try:
for i, r in enumerate(ramDICT[y]): for i, r in enumerate(ramDICT[y]):
ale.setRAM(i,r) ale.setRAM(i,r)
ski_position_setting = ski_position_counter[ski_position] ski_position_setting = ski_position_counter[ski_position]
@ -140,7 +142,12 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
ale.setRAM(14,0) ale.setRAM(14,0)
ale.setRAM(25,x) ale.setRAM(25,x)
ale.setRAM(14,180) # TODO ale.setRAM(14,180) # TODO
except Exception as e:
print(e)
logger.warn(f"Could not run test for x: {x}, y: {y}, ski_position: {ski_position}, velocity: {velocity}")
return (Verdict.INCONCLUSIVE, 0)
num_queries = 0
all_obs = list() all_obs = list()
speed_list = list() speed_list = list()
resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA) resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
@ -152,32 +159,33 @@ def run_single_test(ale, nn_wrapper, x,y,ski_position, velocity, duration=50):
if i % 4 == 0: if i % 4 == 0:
stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])}) stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
action = nn_wrapper.query(stack_tensor) action = nn_wrapper.query(stack_tensor)
num_queries += 1
ale.act(input_to_action(str(action))) ale.act(input_to_action(str(action)))
else: else:
ale.act(input_to_action(str(action))) ale.act(input_to_action(str(action)))
speed_list.append(ale.getRAM()[14]) speed_list.append(ale.getRAM()[14])
if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0: if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
#saveObservations(all_obs, Verdict.BAD, testDir) #saveObservations(all_obs, Verdict.BAD, testDir)
return Verdict.BAD
return (Verdict.BAD, num_queries)
#saveObservations(all_obs, Verdict.GOOD, testDir) #saveObservations(all_obs, Verdict.GOOD, testDir)
return Verdict.GOOD
return (Verdict.GOOD, num_queries)
def computeStateRanking(mdp_file, iteration): def computeStateRanking(mdp_file, iteration):
logger.info("Computing state ranking") logger.info("Computing state ranking")
tic() tic()
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], !\"Hit_Tree\" & !\"Hit_Gate\" & !\"Unsafe\" );"
prop = f"filter(min, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(max, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(avg, Pmin=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(min, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(max, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += f"filter(avg, Pmax=? [ G !(\"Hit_Tree\" | \"Hit_Gate\" | \"Unsafe\") ], (!\"S_Hit_Tree\" & !\"S_Hit_Gate\") | (\"Safe\" | \"Unsafe\") );"
prop += 'Rmax=? [C <= 200]' prop += 'Rmax=? [C <= 200]'
results = list() results = list()
try: try:
command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'" command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --build-all-labels --prop '{prop}'"
output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
for line in output: for line in output:
print(line)
#print(line)
if "Result" in line and not len(results) >= 6: if "Result" in line and not len(results) >= 6:
range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
if range_value: if range_value:
@ -191,7 +199,7 @@ def computeStateRanking(mdp_file, iteration):
# todo die gracefully if ranking is uniform # todo die gracefully if ranking is uniform
print(e.output) print(e.output)
logger.info(f"Computing state ranking - DONE: took {toc()} seconds") logger.info(f"Computing state ranking - DONE: took {toc()} seconds")
return TestResult(*tuple(results))
return TestResult(*tuple(results),0,0,0)
def fillStateRanking(file_name, match=""): def fillStateRanking(file_name, match=""):
logger.info(f"Parsing state ranking, {file_name}") logger.info(f"Parsing state ranking, {file_name}")
@ -274,6 +282,8 @@ def statesFormulaTrimmed(states):
formula += ")" formula += ")"
formula += ")" formula += ")"
formulas.append(formula) formulas.append(formula)
print(formulas)
sys.exit(1)
return createBalancedDisjunction(formulas) return createBalancedDisjunction(formulas)
# https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list # https://stackoverflow.com/questions/5389507/iterating-over-every-two-elements-in-a-list
@ -459,8 +469,8 @@ if __name__ == '__main__':
_init_logger() _init_logger()
logger = logging.getLogger('main') logger = logging.getLogger('main')
logger.info("Starting") logger.info("Starting")
n_clusters = 40
testAll = False testAll = False
num_queries = 0
source = "images/1_full_scaled_down.png" source = "images/1_full_scaled_down.png"
for ski_position in range(1, num_ski_positions + 1): for ski_position in range(1, num_ski_positions + 1):
@ -477,6 +487,10 @@ if __name__ == '__main__':
while True: while True:
updatePrismFile(init_mdp, iteration, safeStates, unsafeStates) updatePrismFile(init_mdp, iteration, safeStates, unsafeStates)
modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration) modelCheckingResult = computeStateRanking(f"{init_mdp}_{iteration:03}.prism", iteration)
if len(results) > 0:
modelCheckingResult.safeStates = results[-1].safeStates
modelCheckingResult.unsafeStates = results[-1].unsafeStates
modelCheckingResult.num_queries = results[-1].num_queries
results.append(modelCheckingResult) results.append(modelCheckingResult)
logger.info(f"Model Checking Result: {modelCheckingResult}") logger.info(f"Model Checking Result: {modelCheckingResult}")
if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps: if abs(modelCheckingResult.init_check_pes_avg - modelCheckingResult.init_check_opt_avg) < eps:
@ -508,8 +522,8 @@ if __name__ == '__main__':
y = state[0].y y = state[0].y
ski_pos = state[0].ski_position ski_pos = state[0].ski_position
velocity = state[0].velocity velocity = state[0].velocity
#result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
result = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
result, num_queries_this_test_case = run_single_test(ale,nn_wrapper,x,y,ski_pos, velocity, duration=50)
num_queries += num_queries_this_test_case
if result == Verdict.BAD: if result == Verdict.BAD:
if testAll: if testAll:
failingPerCluster[id].append(state) failingPerCluster[id].append(state)
@ -522,6 +536,9 @@ if __name__ == '__main__':
clusterResult[id] = (cluster, Verdict.GOOD) clusterResult[id] = (cluster, Verdict.GOOD)
safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]) safeStates.update([(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster])
logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}") logger.info(f"Iteration: {iteration:03}\t-\tSafe Results : {len(safeStates)}\t-\tUnsafe Results:{len(unsafeStates)}")
results[-1].safeStates = len(safeStates)
results[-1].unsafeStates = len(unsafeStates)
results[-1].num_queries = num_queries
if testAll: drawClusters(failingPerCluster, f"failing", iteration) if testAll: drawClusters(failingPerCluster, f"failing", iteration)
drawResult(clusterResult, "result", iteration) drawResult(clusterResult, "result", iteration)

Loading…
Cancel
Save