Browse Source

added method to just consider XY for restriction formulas

add_velocity_into_framework
sp 6 months ago
parent
commit
b30a3b8bf2
  1. 24
      rom_evaluate.py

24
rom_evaluate.py

@ -203,6 +203,20 @@ def clusterFormula(cluster):
return "(" + formulas[0] + ")" return "(" + formulas[0] + ")"
def clusterFormulaXY(cluster):
if len(cluster) == 0: return
formulas = set()
for state in cluster:
formulas.add(f"(x={state[0].x} & y={state[0].y})")
formulas = list(formulas)
while len(formulas) > 1:
formulas_tmp = [f"({formulas[i]} | {formulas[i+1]})" for i in range(0,len(formulas)//2)]
if len(formulas) % 2 == 1:
formulas_tmp.append(formulas[-1])
formulas = formulas_tmp
return "(" + formulas[0] + ")"
def clusterFormulaTrimmed(cluster): def clusterFormulaTrimmed(cluster):
formula = "" formula = ""
states = [(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster] states = [(s[0].x,s[0].y, s[0].ski_position, s[0].velocity) for s in cluster]
@ -271,19 +285,19 @@ def createUnsafeFormula(clusters):
formulas = "" formulas = ""
indices = list() indices = list()
for i, cluster in enumerate(clusters): for i, cluster in enumerate(clusters):
formulas += f"formula Unsafe_{i} = {clusterFormulaTrimmed(cluster)};\n"
formulas += f"formula Unsafe_{i} = {clusterFormulaXY(cluster)};\n"
indices.append(f"Unsafe_{i}") indices.append(f"Unsafe_{i}")
return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe")# + label
return formulas + "\n" + createBalancedDisjunction(indices, "Unsafe") + label
def createSafeFormula(clusters): def createSafeFormula(clusters):
label = "label \"Safe\" = Safe;\n" label = "label \"Safe\" = Safe;\n"
formulas = "" formulas = ""
indices = list() indices = list()
for i, cluster in enumerate(clusters): for i, cluster in enumerate(clusters):
formulas += f"formula Safe_{i} = {clusterFormula(cluster)};\n"
formulas += f"formula Safe_{i} = {clusterFormulaXY(cluster)};\n"
indices.append(f"Safe_{i}") indices.append(f"Safe_{i}")
return formulas + "\n" + createBalancedDisjunction(indices, "Safe")# + label
return formulas + "\n" + createBalancedDisjunction(indices, "Safe") + label
def updatePrismFile(newFile, iteration, safeStates, unsafeStates): def updatePrismFile(newFile, iteration, safeStates, unsafeStates):
logger.info("Creating next prism file") logger.info("Creating next prism file")
@ -317,7 +331,7 @@ x = 70
nn_wrapper = SampleFactoryNNQueryWrapper() nn_wrapper = SampleFactoryNNQueryWrapper()
experiment_id = int(time.time()) experiment_id = int(time.time())
init_mdp = "velocity_safety"
init_mdp = "safety"
exec(f"mkdir -p images/testing_{experiment_id}", verbose=False) exec(f"mkdir -p images/testing_{experiment_id}", verbose=False)
markerSize = 1 markerSize = 1

Loading…
Cancel
Save