sp
11 months ago
commit
b3be734b8f
4 changed files with 667 additions and 0 deletions
-
88plotting.py
-
68simulation.py
-
376test_model.py
-
135translate.py
@ -0,0 +1,88 @@ |
|||
#!/usr/bin/python3 |
|||
|
|||
import visvis as vv |
|||
import numpy as np |
|||
|
|||
import time |
|||
|
|||
def translateArrayToVV(array):
    """Convert interval bounds into a ``(translation, scaling)`` pair.

    Rows 0, 2 and 4 of ``array`` are each read as a ``(lower, upper)`` pair
    (presumably the x, y and z position intervals; rows 1, 3 and 5 are not
    read here -- verify against the caller).

    Returns:
        ``(translation, scaling)`` where ``scaling`` is the extent of each of
        the three intervals and ``translation`` is the midpoint of each.
    """
    bounds = [(float(array[row][0]), float(array[row][1])) for row in (0, 2, 4)]
    scaling = tuple(upper - lower for lower, upper in bounds)
    # Each axis is centred using its OWN extent; the previous code reused the
    # first axis' extent for all three midpoints.
    translation = tuple(
        lower + extent * 0.5 for (lower, _upper), extent in zip(bounds, scaling)
    )
    return translation, scaling
|||
|
|||
class VisVisPlotter():
    """Draw goal, failure and tested states as solid boxes in a visvis scene."""

    def __init__(self, stateValuations, goalStates, deadlockStates, stepwise):
        """Set up the visvis app/figure/axes and collect goal and fail positions.

        ``stateValuations`` maps state ids to objects exposing ``x``, ``y`` and
        ``z``; ``goalStates`` and ``deadlockStates`` are collections of state ids.
        ``stepwise`` is stored and consulted by ``plotStates``.
        """
        self.app = vv.use()
        self.fig = vv.clf()
        self.ax = vv.cla()

        # RGBA face colours and the box size used for scenario states.
        self.stateColor = (0,0,0.75,0.4)
        self.failColor = (0.8,0.8,0.8,1.0)
        self.goalColor = (0,1,0,1.0)
        self.stateScaling = (2,2,2)

        self.plotedStates = set()
        self.plotedMeshes = set()
        self.stepwise = stepwise

        # State ids that are never drawn, even when labelled goal/deadlock.
        auxStates = [0,1,2,25518]

        def positionsOf(selectedIds):
            return {
                (valuation.x, valuation.y, valuation.z)
                for stateId, valuation in stateValuations.items()
                if stateId in selectedIds and stateId not in auxStates
            }

        self.goals = positionsOf(goalStates)
        self.fails = positionsOf(deadlockStates)

    def plotScenario(self, saveScreenshot=True):
        """Draw all goal boxes, then all fail boxes, and fix limits and view."""
        for positions, faceColor in ((self.goals, self.goalColor), (self.fails, self.failColor)):
            for position in positions:
                box = vv.solidBox(position, scaling=self.stateScaling)
                box.faceColor = faceColor

        self.ax.SetLimits((-16,16),(-10,10),(-7,7))
        self.ax.SetView({'zoom':0.025, 'elevation':20, 'azimuth':30})
        if saveScreenshot:
            vv.screenshot("000.png", sf=3, bg='w', ob=vv.gcf())

    def run(self):
        """Reset the axis limits and enter the visvis application loop."""
        self.ax.SetLimits((-16,16),(-10,10),(-7,7))
        self.app.Run()

    def clear(self):
        """Clear the current visvis axes."""
        vv.gca().Clear()

    def plotStates(self, states, coloring="", removeMeshes=False):
        """Draw a unit box for every not-yet-plotted state in ``states``.

        In stepwise mode with ``removeMeshes`` set, the axes are cleared and the
        scenario is redrawn first. States already in ``self.plotedStates`` are
        skipped, and only one box is drawn per (x, y, z) position per call.
        An empty ``coloring`` falls back to ``self.stateColor``.
        """
        if self.stepwise and removeMeshes:
            self.clear()
            self.plotScenario(saveScreenshot=False)

        faceColor = coloring if coloring else self.stateColor
        drawnPositions = set()
        for state in states:
            if state in self.plotedStates:
                continue
            position = (state.x, state.y, state.z)
            print(f"plotting {state} at {position}")
            if position in drawnPositions:
                continue
            mesh = vv.solidBox(position, scaling=(1,1,1))
            mesh.faceColor = faceColor
            drawnPositions.add(position)
            if self.stepwise:
                self.plotedMeshes.add(mesh)
        self.plotedStates.update(states)

    def takeScreenshot(self, iteration, prefix=""):
        """Save one screenshot per configured camera angle for ``iteration``."""
        self.ax.SetLimits((-16,16),(-10,10),(-7,7))
        cameraAngles = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)]
        for elevation, azimuth in cameraAngles:
            filename = f"{prefix}{'_' if prefix else ''}{iteration:03}_{elevation}_{azimuth}.png"
            print(f"Saving Screenshot to {filename}")
            self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth})
            vv.screenshot(filename, sf=3, bg='w', ob=vv.gcf())

    def turnCamera(self):
        """Step the camera through the configured angles without saving."""
        cameraAngles = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)]
        self.ax.SetLimits((-16,16),(-10,10),(-7,7))
        for elevation, azimuth in cameraAngles:
            self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth})
|||
|
|||
|
|||
if __name__ == '__main__':
    # This module defines no main(); the previous `main()` call could only
    # raise NameError. Exit with an explicit message instead.
    raise SystemExit("plotting.py only provides VisVisPlotter; import it rather than running it.")
@ -0,0 +1,68 @@ |
|||
import sys |
|||
from enum import Flag, auto |
|||
|
|||
import numpy as np |
|||
|
|||
|
|||
class Verdict(Flag):
    """Outcome of a simulated test run."""

    INCONCLUSIVE = auto()
    PASS = auto()
    FAIL = auto()


class Simulator():
    """Randomly simulate a strategy on an MDP given as state/action pairs."""

    def __init__(self, allStateActionPairs, strategy, deadlockStates, reachedStates, bound=3, numSimulations=1):
        """Index the transitions and collect the candidate start states.

        Args:
            allStateActionPairs: objects exposing ``state_id``, ``action_id``
                and ``next_state_probabilities`` (successor id -> probability).
            strategy: mapping from state id to the action id to take there.
            deadlockStates: state ids whose visit yields ``Verdict.FAIL``.
            reachedStates: state ids whose visit yields ``Verdict.PASS``.
            bound: maximum number of steps per simulation.
            numSimulations: simulations performed per test case.
        """
        self.allStateActionPairs = {
            (pair.state_id, pair.action_id): pair.next_state_probabilities
            for pair in allStateActionPairs
        }
        self.strategy = strategy
        self.deadlockStates = deadlockStates
        self.reachedStates = reachedStates
        # Hashed copies so the per-step membership tests are O(1).
        self._deadlockSet = frozenset(deadlockStates)
        self._reachedSet = frozenset(reachedStates)

        self.bound = bound
        self.numSimulations = numSimulations

        # Test cases are drawn from source states that are neither deadlock
        # nor reached states.
        candidates = {pair.state_id for pair in allStateActionPairs}
        candidates -= self._deadlockSet
        candidates -= self._reachedSet
        self.allStates = np.array(list(candidates))

    def _pickRandomTestCase(self):
        """Return one uniformly drawn state id from ``self.allStates``."""
        return np.random.choice(self.allStates, 1)[0]

    def _simulate(self, initialStateId):
        """Follow the strategy from ``initialStateId`` for at most ``bound`` steps.

        Returns:
            ``(verdict, steps)``: FAIL/PASS as soon as a deadlock/reached state
            is sampled, otherwise INCONCLUSIVE once the bound is exhausted.

        Raises:
            KeyError: if the strategy or the transition table lacks an entry
                for a visited state.
        """
        steps = 0
        statePair = (initialStateId, self.strategy[initialStateId])

        while steps < self.bound:
            steps += 1
            distribution = self.allStateActionPairs[statePair]
            nextStateIds = list(distribution.keys())
            weights = list(distribution.values())
            nextStateId = np.random.choice(nextStateIds, 1, p=weights)[0]
            if nextStateId in self._deadlockSet:
                return Verdict.FAIL, steps
            if nextStateId in self._reachedSet:
                return Verdict.PASS, steps
            statePair = (nextStateId, self.strategy[nextStateId])
        return Verdict.INCONCLUSIVE, steps

    def runTest(self):
        """Pick a random test case and simulate it ``numSimulations`` times.

        Returns:
            ``(testCase, verdict, numQueries)``. ``verdict`` is FAIL as soon as
            one simulation fails and INCONCLUSIVE otherwise. ``numQueries`` is
            the total number of steps spent over all simulations performed for
            this test case (0 when ``numSimulations`` is 0).
        """
        testCase = self._pickRandomTestCase()

        totalQueries = 0
        for _ in range(self.numSimulations):
            result, queries = self._simulate(testCase)
            totalQueries += queries
            if result == Verdict.FAIL:
                return testCase, Verdict.FAIL, totalQueries
        return testCase, Verdict.INCONCLUSIVE, totalQueries
@ -0,0 +1,376 @@ |
|||
#!/usr/bin/python3 |
|||
|
|||
import re, sys, os, shutil, fileinput, subprocess, argparse |
|||
from dataclasses import dataclass, field |
|||
import visvis as vv |
|||
import numpy as np |
|||
import argparse |
|||
|
|||
from translate import translateTransitions, readLabels |
|||
from plotting import VisVisPlotter |
|||
from simulation import Simulator, Verdict |
|||
|
|||
|
|||
def convert(tuples):
    """Return a new dictionary built by passing ``tuples`` to ``dict``."""
    mapping = dict(tuples)
    return mapping
|||
|
|||
def getBasename(filename):
    """Return the final path component of ``filename``."""
    _directory, basename = os.path.split(filename)
    return basename
|||
|
|||
def traFileWithIteration(filename, iteration):
    """Return ``filename`` with its extension replaced by ``_<iteration>.tra``.

    ``iteration`` is zero-padded to three digits.
    """
    stem, _extension = os.path.splitext(filename)
    return f"{stem}_{iteration:03}.tra"
|||
|
|||
def copyFile(filename, newFilename):
    """Copy ``filename`` to ``newFilename`` using ``shutil.copy``."""
    source, destination = filename, newFilename
    shutil.copy(source, destination)
|||
|
|||
|
|||
def execute(command, verbose=False):
    """Run ``command`` through ``os.system``; announce it first when ``verbose``.

    The exit status of the command is discarded.
    """
    if verbose:
        print(f"Executing {command}")
    os.system(command)
|||
|
|||
@dataclass(frozen=True)
class State:
    """Immutable (hashable) valuation of one model state.

    Field names suggest a position and a velocity per axis -- confirm units
    and meaning against the producer of the state valuation file.
    """

    # Numeric identifier of the state.
    id: int
    # x component and its velocity.
    x: float
    x_vel: float
    # y component and its velocity.
    y: float
    y_vel: float
    # z component and its velocity.
    z: float
    z_vel: float
|||
|
|||
def default_value():
    """Return a fresh default ``choices`` dictionary with both entries unset."""
    return dict(action=None, choiceValue=None)


@dataclass(frozen=True)
class StateValue:
    """Ranking value of a state together with its per-choice values."""

    # Ranking score of the state.
    ranking: float
    # Per-choice data; every instance gets its own default dictionary.
    choices: dict = field(default_factory=default_value)
|||
|
|||
@dataclass(frozen=False)
class TestResult:
    """The eight numeric results collected from one model-checker call.

    The names suggest pessimistic/optimistic probability aggregates followed
    by two reward minima -- confirm against the property list that produces
    them.
    """

    prob_pes_min: float
    prob_pes_max: float
    prob_pes_avg: float
    prob_opt_min: float
    prob_opt_max: float
    prob_opt_avg: float
    min_min: float
    min_max: float

    def csv(self, ws=" "):
        """Format all fields with four decimals, each followed by ``ws``."""
        ordered = (
            self.prob_pes_min, self.prob_pes_max, self.prob_pes_avg,
            self.prob_opt_min, self.prob_opt_max, self.prob_opt_avg,
            self.min_min, self.min_max,
        )
        return "".join(f"{number:0.04f}{ws}" for number in ordered)
|||
|
|||
def parseStrategy(strategyFile, allStateActionPairs, time_index=0):
    """Read a strategy file into a ``state id -> action index`` mapping.

    Each non-blank line is split on ``,`` and ``=`` after removing
    parentheses, giving ``state, time, action name``. The state id is shifted
    by 3. Lines whose time differs from ``time_index`` are ignored, as are
    states without an entry in ``allStateActionPairs``.

    Args:
        strategyFile: path of the strategy file.
        allStateActionPairs: mapping from state id to its list of action names.
        time_index: only entries with this time value are used.

    Returns:
        dict mapping state id to the index of the chosen action name within
        that state's action list.

    Raises:
        ValueError: if a known state names an action that is not in its list,
            or if a line does not start with two integers.
    """
    strategy = dict()
    with open(strategyFile) as strategyLines:
        for line in strategyLines:
            # strip() also drops "\r" and stray spaces, so blank or
            # CRLF-terminated lines no longer break the integer parsing.
            line = line.replace("(", "").replace(")", "").strip()
            if not line:
                continue
            explode = [token.strip() for token in re.split(",|=", line)]
            stateId = int(explode[0]) + 3
            if stateId < 3:
                continue
            if int(explode[1]) != time_index:
                continue
            try:
                actionNames = allStateActionPairs[stateId]
            except KeyError:
                # Deliberate best effort: states unknown to the model are skipped.
                continue
            strategy[stateId] = actionNames.index(explode[2])

    return strategy
|||
|
|||
def queryStrategy(strategy, stateId):
    """Return the action stored for ``stateId``, or ``-1`` when there is none.

    Only lookup failures are turned into ``-1``; any other error (including
    KeyboardInterrupt) now propagates instead of being swallowed.
    """
    try:
        return strategy[stateId]
    except LookupError:
        return -1
|||
|
|||
|
|||
def callTempest(files, reward, bound=3):
    """Run the storm binary on ``files`` and collect its results.

    Eight ``filter(...)`` properties are checked over all states that are
    neither "failed" nor "reached": min/max/avg of ``Pmax`` and of ``Pmin`` for
    reaching "failed" within ``bound`` steps, followed by the minimum of
    ``Rmin`` and of ``Rmax`` for the reward accumulated within ``bound`` steps.
    (An alternative reward-only property set existed behind a disabled branch
    and is not generated.)

    Every output line containing "Result" contributes either the two bounds of
    an ``[a, b]`` range or the single value after the colon; lines are ignored
    once ten values were collected. If the command fails, its output is printed
    and no values are collected.

    Returns:
        ``TestResult`` built from the collected values.
    """
    property_str = "!(\"failed\" | \"reached\")"

    queries = []
    for operator in ("Pmax", "Pmin"):
        for aggregate in ("min", "max", "avg"):
            queries.append(f"filter({aggregate}, {operator}=? [ true U<={bound} \"failed\" ], {property_str} );")
    for operator in ("Rmin", "Rmax"):
        queries.append(f"filter(min, {operator}=? [ C<={bound} ], {property_str} );")
    prop = "".join(queries)

    command = f"~/projects/tempest-devel/ranking_release/bin/storm --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' "

    results = list()
    try:
        output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n')
    except subprocess.CalledProcessError as e:
        print(e.output)
    else:
        for line in output:
            if "Result" not in line or len(results) >= 10:
                continue
            interval = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line)
            if interval:
                results.append(float(interval.group(2)))
                results.append(float(interval.group(3)))
            else:
                scalar = re.search(r"(.*:)(.*)", line)
                results.append(float(scalar.group(2)))
    return TestResult(*(tuple(results)))
|||
|
|||
def parseRanking(filename, allStates):
    """Read a ranking file and return states ordered by normalised ranking.

    Every non-blank line must start with a state id; the first ``:<number>``
    on the line is the ranking value and the remaining ones are per-choice
    values, keyed by their position. Rankings are rescaled linearly to
    ``[0, 1]``; if all rankings are equal every state gets ``0.0``.

    Args:
        filename: path of the ranking file.
        allStates: mapping from state id to the state object used as key.

    Returns:
        dict ``state -> StateValue`` in ascending ranking order. An empty dict
        is returned when the file contains no rankings (previously ``None``,
        which made callers fail on ``.items()`` / ``.keys()``).

    Exits the process with status 1 when the file cannot be read.
    """
    state_ranking = dict()
    try:
        with open(filename, "r") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing empty lines
                stateId = int(re.findall(r"^\d+", line)[0])
                values = re.findall(r":(-?\d+\.?\d*),?", line)
                ranking_value = float(values[0])
                choices = {i: float(value) for i, value in enumerate(values[1:])}
                state_ranking[allStates[stateId]] = StateValue(ranking_value, choices)
    except EnvironmentError:
        print("TODO file not available. Exiting.")
        sys.exit(1)

    if not state_ranking:
        return {}

    all_values = [value.ranking for value in state_ranking.values()]
    max_value = max(all_values)
    min_value = min(all_values)

    normalised = {}
    for state, value in state_ranking.items():
        try:
            new_value = (value.ranking - min_value) / (max_value - min_value)
        except ZeroDivisionError:
            # All rankings identical: there is no spread to normalise over.
            new_value = 0.0
        normalised[state] = StateValue(new_value, value.choices)

    return dict(sorted(normalised.items(), key=lambda item: item[1].ranking))
|||
|
|||
def parseStateValuations(filename):
    """Read state valuations from ``filename`` into ``state id -> State``.

    States 0, 1 and 2 are pre-filled with placeholder valuations whose seven
    fields all equal the id. Each line of the file supplies the numbers for one
    ``State`` (first an integer id, then floats). Finally one more placeholder
    state with id ``max id + 1`` is appended.
    """
    numberPattern = re.compile(r"(-?\d+\.?\d*),?")

    all_states = {auxId: State(*([auxId] * 7)) for auxId in (0, 1, 2)}
    maxStateId = -1

    with open(filename) as stateValuations:
        for line in stateValuations:
            tokens = numberPattern.findall(line)
            stateId = int(tokens[0])
            all_states[stateId] = State(stateId, *[float(token) for token in tokens[1:]])
            maxStateId = max(maxStateId, stateId)

    sentinelId = maxStateId + 1
    all_states[sentinelId] = State(*([sentinelId] * 7))
    return all_states
|||
|
|||
def parseResults(allStates):
    """Combine ``prob_results_maximize`` and ``prob_results_minimize``.

    Both files are read line by line in lockstep from the current working
    directory. On each line the first number is the state id and the second
    the result value.

    Args:
        allStates: unused; kept for interface compatibility.

    Returns:
        dict ``state id -> (min_result, max_result, max_result - min_result)``.

    Raises:
        AssertionError: if the two files disagree on the state id of a line.
            Raised explicitly so the check also holds under ``python -O``.
    """
    numberPattern = re.compile(r"(-?\d+\.?\d*),?")
    state_to_values = dict()
    with open("prob_results_maximize") as maximizer, open("prob_results_minimize") as minimizer:
        for max_line, min_line in zip(maximizer, minimizer):
            max_values = numberPattern.findall(max_line)
            min_values = numberPattern.findall(min_line)
            if max_values[0] != min_values[0]:
                print("min/max files do not match.")
                raise AssertionError(
                    f"min/max files do not match: state {max_values[0]} vs {min_values[0]}"
                )
            min_result = float(min_values[1])
            max_result = float(max_values[1])
            state_to_values[int(max_values[0])] = (min_result, max_result, max_result - min_result)
    return state_to_values
|||
|
|||
|
|||
def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration):
    """Rewrite ``filename`` in place, keeping only one action of ``stateId``.

    Lines that do not start with ``stateId`` followed by whitespace are kept
    unchanged. Lines of ``stateId`` are kept only when their second
    space-separated field equals ``chosenActionIndex``.

    Args:
        stateId: id of the state whose other actions are removed.
        chosenActionIndex: the action index to keep.
        filename: transition file to rewrite.
        iteration: unused; kept for interface compatibility.
    """
    # Raw f-string: "\s" in a normal string literal is an invalid escape.
    stateIdRegex = re.compile(rf"^{stateId}\s")
    # The context manager restores stdout and closes the file even on errors.
    with fileinput.input(filename, inplace=True) as lines:
        for line in lines:
            if not stateIdRegex.match(line):
                print(line, end="")
            elif int(line.split(" ")[1]) == chosenActionIndex:
                print(line, end="")
|||
|
|||
def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration):
    """Rewrite ``filename`` in place, keeping one action per listed state.

    For every state id in ``stateActionPairsToTrim`` only the lines whose
    second field equals the mapped action index are kept; all other lines are
    copied unchanged.

    Args:
        stateActionPairsToTrim: mapping ``state id -> action index to keep``.
        filename: transition file to rewrite.
        iteration: unused; kept for interface compatibility.
    """
    # An empty mapping used to compile to an empty regex that matched every
    # line and then crashed on int(""); there is simply nothing to trim.
    if not stateActionPairsToTrim:
        return

    with fileinput.input(filename, inplace=True) as lines:
        for line in lines:
            fields = line.split()
            keep = True
            if len(fields) >= 2:
                try:
                    stateId = int(fields[0])
                except ValueError:
                    stateId = None  # header or other non-transition line
                if stateId in stateActionPairsToTrim:
                    keep = int(fields[1]) == stateActionPairsToTrim[stateId]
            if keep:
                print(line, end="")
|||
|
|||
def getTopNStates(rankedStates, n, threshold):
    """Select states from ``rankedStates``.

    With ``n == 0`` every state whose ``ranking`` is at least ``threshold`` is
    returned; otherwise the last ``n`` entries (in the mapping's own order)
    are returned and ``threshold`` is ignored.
    """
    if n == 0:
        selected = {}
        for state, value in rankedStates.items():
            if value.ranking >= threshold:
                selected[state] = value
        return selected

    entries = list(rankedStates.items())
    return dict(entries[-n:])
|||
|
|||
def getNRandomStates(rankedStates, n, testedStates):
    """Draw up to ``n`` distinct, not yet tested state ids at random.

    Args:
        rankedStates: mapping whose keys expose an ``id`` attribute.
        n: number of state ids to draw.
        testedStates: iterable of state ids to exclude.

    Returns:
        numpy array with ``n`` ids sampled without replacement, or all
        remaining ids (in mapping order) when fewer than ``n`` are left.
    """
    # Hash the exclusions once; testing against the list was O(len) per state.
    alreadyTested = set(testedStates)
    notYetTestedStates = np.array(
        [state.id for state in rankedStates if state.id not in alreadyTested]
    )
    if len(notYetTestedStates) >= n:
        picks = np.random.choice(len(notYetTestedStates), size=n, replace=False)
        return notYetTestedStates[picks]
    return notYetTestedStates
|||
|
|||
def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False, safetyThreshold=0.05, totalIterations=60):
    """Iteratively model-check and trim the MDP according to the strategy.

    Each iteration calls the model checker on the current transition file,
    selects states to test (top-ranked ones, or random ones when
    ``ablationTesting``), restricts those states to the action the strategy
    picks in a copy of the transition file, and classifies every non-final
    state using the estimates read by ``parseResults``.

    Args:
        traFile, labFile, straFile: transition, label and strategy files.
        horizonBound: step bound passed to the model checker.
        refinementSteps: number of states tested per iteration (0 selects by
            ``refinementBound`` instead).
        refinementBound: ranking threshold used when ``refinementSteps`` is 0.
        ablationTesting: pick random untested states instead of ranked ones.
        plotting: draw the scenario and the classified states.
        stepwisePlotting: forwarded to ``VisVisPlotter``.
        safetyThreshold: a state counts as failing when its first estimate
            exceeds this value and as validated when its second estimate does
            not; previously hard-coded as 0.05.
        totalIterations: maximum number of iterations; previously hard-coded
            as 60.

    Exits the process with status 0 as soon as no state is left undecided.
    """
    all_states = parseStateValuations("MDP_state_valuations")
    deadlockStates, reachedStates, maxStateId = readLabels(labFile)
    stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
    strategy = parseStrategy(straFile, stateToActions)

    plotter = None
    if plotting:
        plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
        plotter.plotScenario()

    copyFile("MDP_" + traFile, "MDP_" + os.path.splitext(getBasename(traFile))[0] + f"_000.tra")

    # Hashed once: these membership tests run for every state in every iteration.
    finalStates = set(deadlockStates) | set(reachedStates)

    iteration = 0
    numTestedStates = 0
    testedStates = list()
    while iteration < totalIterations:
        print(f"{iteration:03}", end="\t")
        sys.stdout.flush()
        currentTraFile = traFileWithIteration("MDP_" + traFile, iteration)
        nextTraFile = traFileWithIteration("MDP_" + traFile, iteration+1)
        testResult = callTempest(f"{currentTraFile} MDP_{labFile}", "saferew", horizonBound)
        # An empty ranking may come back as None; treat it as "no states".
        state_ranking = parseRanking("action_ranking", all_states) or {}
        copyFile("action_ranking", f"action_ranking_{iteration:03}")
        copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}")
        copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}")

        if not ablationTesting:
            importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound)
            statesToTest = [state.id for state in importantStates]
        else:
            statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates))
            testedStates += statesToTest

        copyFile(currentTraFile, nextTraFile)

        # Only states for which the strategy has an answer can be trimmed.
        stateActionPairsToTrim = dict()
        for testState in statesToTest:
            chosenActionIndex = queryStrategy(strategy, testState)
            if chosenActionIndex != -1:
                stateActionPairsToTrim[testState] = chosenActionIndex

        stateEstimates = parseResults(all_states)

        failureStates = list()
        validatedStates = list()
        undecided = 0
        for state, estimates in stateEstimates.items():
            if state in finalStates:
                continue
            if estimates[0] > safetyThreshold:
                failureStates.append(all_states[state])
            elif estimates[1] <= safetyThreshold:
                validatedStates.append(all_states[state])
            else:
                undecided += 1

        if stateActionPairsToTrim:
            removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration)

        numValidated = len(validatedStates)
        numFailures = len(failureStates)
        print(f"{numTestedStates}\t{testResult.csv(' ')}\t{numValidated}\t{numFailures}\t{undecided}\t{numValidated + numFailures + undecided}")
        if undecided == 0:
            sys.exit(0)
        numTestedStates += len(statesToTest)
        iteration += 1

        if plotter is not None:
            plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6), removeMeshes=True)
            plotter.plotStates(validatedStates, coloring=(0.0,0.8,0.0,0.6))
            plotter.takeScreenshot(iteration, prefix=f"stepwise_{safetyThreshold}")
|||
|
|||
def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False, stepwisePlotting=False):
    """Test the strategy by random simulation until the query budget is spent.

    Random test cases are simulated with ``Simulator`` until more than
    ``maxQueries`` simulation steps were used. The query count and the number
    of failing test cases are printed whenever at least 500 further queries
    accumulated, and once more at the end.

    Args:
        traFile, labFile, straFile: transition, label and strategy files.
        bound: step bound of a single simulation.
        maxQueries: total simulation step budget.
        plotting: plot failing states at every progress report and wait for
            Enter before continuing.
        stepwisePlotting: forwarded to ``VisVisPlotter``. This used to be read
            from an undefined local name and only worked through a
            script-level global.
    """
    all_states = parseStateValuations("MDP_state_valuations")
    deadlockStates, reachedStates, maxStateId = readLabels(labFile)
    stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
    strategy = parseStrategy(straFile, stateToActions)

    plotter = None
    if plotting:
        plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting)
        plotter.plotScenario()

    randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound)
    print("Starting with random testing.")

    queriesSinceReport = 0
    numQueries = 0
    # Numbers the screenshot batches; the previous code referenced an
    # undefined `iteration` here and raised NameError.
    reportIndex = 0
    failureStates = list()
    while numQueries <= maxQueries:
        if queriesSinceReport >= 500:
            if plotter is not None:
                plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6))
                plotter.takeScreenshot(reportIndex, prefix="random_testing")
                plotter.turnCamera()
                input("")
            print(f"{numQueries} {len(failureStates)} ")
            reportIndex += 1
            queriesSinceReport = 0

        testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest()
        queriesSinceReport += queriesForThisTestCase
        numQueries += queriesForThisTestCase
        if testResult == Verdict.FAIL:
            failureStates.append(all_states[testCase])

    print(f"{numQueries} {len(failureStates)} ")
|||
|
|||
def parseArgs():
    """Parse the command line and return the resulting namespace.

    Requires the four input files and exactly one of ``--refinement-steps`` /
    ``--refinement-bound``; ``--ablation`` and ``--random`` exclude each other.
    """
    parser = argparse.ArgumentParser()

    requiredFiles = (
        ('--tra', 'Path to .tra file.'),
        ('--lab', 'Path to .lab file.'),
        ('--rew', 'Path to .rew file.'),
        ('--stra', 'Path to strategy file.'),
    )
    for flag, description in requiredFiles:
        parser.add_argument(flag, type=str, required=True, help=description)

    refinementGroup = parser.add_mutually_exclusive_group(required=True)
    refinementGroup.add_argument('--refinement-steps', type=int, default=0, help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.')
    refinementGroup.add_argument('--refinement-bound', type=float, default=0, help='Threshold value for states to be tested, mutually exclusive with refinement-steps.')

    parser.add_argument('--bound', type=int, required=False, default=3, help='(optional) Safety Horizon Bound, defaults to 3.')
    parser.add_argument('--threshold', type=float, required=False, default=0.05, help='(optional) Safety Threshold, defaults to 0.05.')

    testingGroup = parser.add_mutually_exclusive_group()
    testingGroup.add_argument('-a', '--ablation', action='store_true', help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.")
    testingGroup.add_argument('-r', '--random', type=int, default=0, help='(optional) The amount of queries allowed for random testing.')

    parser.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.')
    parser.add_argument('--stepwise', action='store_true', help='(optional) Remove states before plotting the next iteration.')

    parsed = parser.parse_args()
    return parsed
|||
|
|||
if __name__ == '__main__':
    args = parseArgs()

    # These names stay module-level globals: code in this file may read them.
    traFile, labFile, straFile, rewFile = args.tra, args.lab, args.stra, args.rew
    ablationTesting, plotting, stepwisePlotting = args.ablation, args.plotting, args.stepwise
    maxQueriesForRandomTesting, horizonBound = args.random, args.bound
    refinementSteps, refinementBound = args.refinement_steps, args.refinement_bound

    # A non-zero --random budget selects random testing; 0 (the default)
    # selects the ranking-based refinement loop.
    if maxQueriesForRandomTesting != 0:
        randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting)
    else:
        main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting)
@ -0,0 +1,135 @@ |
|||
#!/usr/bin/python3 |
|||
|
|||
import sys |
|||
#import re |
|||
import json |
|||
import numpy as np |
|||
from collections import deque |
|||
|
|||
from dataclasses import dataclass |
|||
|
|||
|
|||
@dataclass(frozen=False)
class StateAction:
    """One action of one state together with its successor distribution."""

    state_id: int
    action_id: int
    # Maps successor state id -> probability.
    next_state_probabilities: dict
    action_name: str

    def normalizeDistribution(self):
        """Rescale the successor probabilities so that they sum to one."""
        total = np.sum(list(self.next_state_probabilities.values()))
        self.next_state_probabilities = {
            successor: (mass / total)
            for successor, mass in self.next_state_probabilities.items()
        }


def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId):
    """Translate an interval transition file into ``MDP_<traFile>``.

    The first line of ``traFile`` is skipped. Every further line holds
    ``state action successor [low,high] [action name]`` separated by single
    spaces; the midpoint of the interval is used as probability. Transitions
    of states 0, 1 and 2 are dropped. Consecutive lines with the same state
    and action are merged into one ``StateAction`` whose distribution is
    normalised before it is written.

    Deadlock states get a single transition to state 0 and reached states a
    single transition to ``maxStateId``. They are interleaved with the source
    states in ascending id order, and those still pending after the last
    source state are written as well (they used to be dropped). State 0 leads
    to ``maxStateId``, which loops on itself.

    Args:
        traFile: input transition file; the output is ``"MDP_" + traFile``.
        deadlockStates: ids of deadlock states.
        reachedStates: ids of reached states.
        maxStateId: id of the additional sink state.

    Returns:
        ``(state_to_actions, all_state_action_pairs)``: a dict mapping each
        state id to its action names in file order, and the list of
        ``StateAction`` objects.
    """
    current_state_id = -1
    current_action_id = -1
    all_state_action_pairs = list()
    with open(traFile) as transitions:
        next(transitions)
        for line in transitions:
            explode = line.replace("\n", "").split(" ")
            if len(explode) < 2:
                continue
            interval = json.loads(explode[3])
            probability = (interval[0] + interval[1]) / 2
            state_id = int(explode[0])
            action_id = int(explode[1])
            next_state_id = int(explode[2])
            action_name = explode[4] if len(explode) >= 5 else ""
            if state_id in (0, 1, 2):
                continue

            same_pair = current_state_id == state_id and current_action_id == action_id
            if same_pair:
                all_state_action_pairs[-1].next_state_probabilities[next_state_id] = probability
            else:
                all_state_action_pairs.append(
                    StateAction(state_id, action_id, {next_state_id: probability}, action_name)
                )
                current_state_id = state_id
                current_action_id = action_id

    # Final states are kept sorted (largest id first) so they can be popped in
    # ascending order while the output file is being written.
    final_states = [(state, 0) for state in deadlockStates]
    final_states += [(state, maxStateId) for state in reachedStates]
    final_states = deque(sorted(final_states, key=lambda pair: pair[0], reverse=True))

    with open("MDP_" + traFile, "w") as new_transitions_file:

        def writeFinalState(final_state):
            # State 0 already received its transition in the file header.
            if int(final_state[0]) == 0:
                return
            new_transitions_file.write(f"{final_state[0]} 0 {final_state[1]} 1.0\n")

        new_transitions_file.write("mdp\n")
        new_transitions_file.write(f"0 0 {maxStateId} 1.0\n")
        for entry in all_state_action_pairs:
            entry.normalizeDistribution()
            source_state = int(entry.state_id)
            while final_states and int(final_states[-1][0]) < source_state:
                writeFinalState(final_states.pop())
            for next_state_id, probability in entry.next_state_probabilities.items():
                new_transitions_file.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n")
        # Flush the final states that do not precede any source state.
        while final_states:
            writeFinalState(final_states.pop())
        new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n")

    state_to_actions = dict()
    for state_action_pair in all_state_action_pairs:
        state_to_actions.setdefault(state_action_pair.state_id, []).append(state_action_pair.action_name)
    return state_to_actions, all_state_action_pairs
|||
|
|||
def readLabels(labFile):
    """Read ``labFile`` and write the derived label and reward files.

    The first line of ``labFile`` is skipped. Every further line holds a state
    id followed by space-separated label indices (colons are ignored). Three
    files are written next to the current working directory:

    * ``MDP_<labFile>``: the labels, where indices 1 and 2 are dropped so that
      those states are no longer labelled deadlock/reached,
    * ``MDP_<labFile>.optrew``: -100 for state 0, 100 for lines containing
      the token "2", -1 otherwise,
    * ``MDP_<labFile>.saferew``: -100 for state 0 only.

    All three files are closed before returning, so they are complete when a
    subsequent tool reads them.

    Returns:
        ``(deadlockStates, reachedStates, maxStateId + 1)``: the ids carrying
        label index 1, the ids carrying label index 2, and one more than the
        largest state id seen.
    """
    deadlockStates = list()
    reachedStates = list()
    labels = ["init", "deadlock", "reached", "failed"]
    newLabFile = "MDP_" + labFile
    maxStateId = -1

    with open(labFile) as states, \
            open(newLabFile, "w") as newLabels, \
            open(newLabFile + ".optrew", "w") as optRewards, \
            open(newLabFile + ".saferew", "w") as safetyRewards:
        next(states)
        newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n")
        for line in states:
            explode = line.replace(":", "").replace("\n", "").split(" ")
            stateToken = explode[0]
            stateId = int(stateToken)
            maxStateId = max(maxStateId, stateId)

            if stateId == 0:
                safetyRewards.write(f"{stateToken} -100\n")
                optRewards.write(f"{stateToken} -100\n")
            # NOTE(review): this membership test also sees the state id token
            # itself, so state "2" matches regardless of its labels -- confirm
            # whether that is intended.
            elif "2" in explode:
                optRewards.write(f"{stateToken} 100\n")
            else:
                optRewards.write(f"{stateToken} -1\n")

            newLabel = f"{stateToken} "
            for labelIndex in explode[1:]:
                # Sink states should not be deadlock states anymore, so the
                # deadlock/reached labels are recorded but not written out.
                if labelIndex == "1":
                    deadlockStates.append(stateId)
                elif labelIndex == "2":
                    reachedStates.append(stateId)
                else:
                    newLabel += f"{labels[int(labelIndex)]} "
            newLabels.write(newLabel + "\n")

    return deadlockStates, reachedStates, maxStateId + 1
|||
|
|||
def main(traFile, labFile):
    """Read the labels of ``labFile`` and translate ``traFile`` with them."""
    labelInfo = readLabels(labFile)
    deadlockStates, reachedStates, maxStateId = labelInfo
    translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
|||
|
|||
if __name__ == '__main__':
    # Usage: translate.py <traFile> <labFile>
    traFile, labFile = sys.argv[1], sys.argv[2]
    main(traFile, labFile)
Write
Preview
Loading…
Cancel
Save
Reference in new issue