sp
11 months ago
commit
b3be734b8f
4 changed files with 667 additions and 0 deletions
-
88plotting.py
-
68simulation.py
-
376test_model.py
-
135translate.py
@ -0,0 +1,88 @@ |
|||||
|
#!/usr/bin/python3 |
||||
|
|
||||
|
import visvis as vv |
||||
|
import numpy as np |
||||
|
|
||||
|
import time |
||||
|
|
||||
|
def translateArrayToVV(array): |
||||
|
scaling = (float(array[0][1]) - float(array[0][0]), float(array[2][1]) - float(array[2][0]), float(array[4][1]) - float(array[4][0])) |
||||
|
translation = (float(array[0][0]) + scaling[0] * 0.5, float(array[2][0]) + scaling[0] * 0.5, float(array[4][0]) + scaling[0] * 0.5) |
||||
|
return translation, scaling |
||||
|
|
||||
|
class VisVisPlotter(): |
||||
|
def __init__(self, stateValuations, goalStates, deadlockStates, stepwise): |
||||
|
self.app = vv.use() |
||||
|
self.fig = vv.clf() |
||||
|
self.ax = vv.cla() |
||||
|
self.stateColor = (0,0,0.75,0.4) |
||||
|
self.failColor = (0.8,0.8,0.8,1.0) # (1,0,0,0.8) |
||||
|
self.goalColor = (0,1,0,1.0) |
||||
|
self.stateScaling = (2,2,2) #(.5,.5,.5) |
||||
|
|
||||
|
self.plotedStates = set() |
||||
|
self.plotedMeshes = set() |
||||
|
self.stepwise = stepwise |
||||
|
|
||||
|
auxStates = [0,1,2,25518] |
||||
|
self.goals = set([(stateValuation.x, stateValuation.y, stateValuation.z) for stateId, stateValuation in stateValuations.items() if stateId in goalStates and not stateId in auxStates]) |
||||
|
self.fails = set([(stateValuation.x, stateValuation.y, stateValuation.z) for stateId, stateValuation in stateValuations.items() if stateId in deadlockStates and not stateId in auxStates]) |
||||
|
|
||||
|
def plotScenario(self, saveScreenshot=True): |
||||
|
for goal in self.goals: |
||||
|
state = vv.solidBox(goal, scaling=self.stateScaling) |
||||
|
state.faceColor = self.goalColor |
||||
|
for fail in self.fails: |
||||
|
state = vv.solidBox(fail, scaling=self.stateScaling) |
||||
|
state.faceColor = self.failColor |
||||
|
|
||||
|
self.ax.SetLimits((-16,16),(-10,10),(-7,7)) |
||||
|
self.ax.SetView({'zoom':0.025, 'elevation':20, 'azimuth':30}) |
||||
|
if saveScreenshot: vv.screenshot("000.png", sf=3, bg='w', ob=vv.gcf()) |
||||
|
|
||||
|
def run(self): |
||||
|
self.ax.SetLimits((-16,16),(-10,10),(-7,7)) |
||||
|
self.app.Run() |
||||
|
|
||||
|
def clear(self): |
||||
|
axes = vv.gca() |
||||
|
axes.Clear() |
||||
|
|
||||
|
def plotStates(self, states, coloring="", removeMeshes=False): |
||||
|
if self.stepwise and removeMeshes: |
||||
|
self.clear() |
||||
|
self.plotScenario(saveScreenshot=False) |
||||
|
if not coloring: |
||||
|
coloring = self.stateColor |
||||
|
plotedRegions = set() |
||||
|
for state in states: |
||||
|
if state in self.plotedStates: continue # what are the implications of this? |
||||
|
coordinates = (state.x, state.y, state.z) |
||||
|
print(f"plotting {state} at {coordinates}") |
||||
|
if coordinates in plotedRegions: continue |
||||
|
state = vv.solidBox(coordinates, scaling=(1,1,1))#(0.5,0.5,0.5)) |
||||
|
state.faceColor = coloring |
||||
|
plotedRegions.add(coordinates) |
||||
|
if self.stepwise: |
||||
|
self.plotedMeshes.add(state) |
||||
|
self.plotedStates.update(states) |
||||
|
|
||||
|
def takeScreenshot(self, iteration, prefix=""): |
||||
|
self.ax.SetLimits((-16,16),(-10,10),(-7,7)) |
||||
|
#config = [(90, 00), (45, 00), (0, 00), (-45, 00), (-90, 00)] |
||||
|
config = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)] |
||||
|
for elevation, azimuth in config: |
||||
|
filename = f"{prefix}{'_' if prefix else ''}{iteration:03}_{elevation}_{azimuth}.png" |
||||
|
print(f"Saving Screenshot to {filename}") |
||||
|
self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth}) |
||||
|
vv.screenshot(filename, sf=3, bg='w', ob=vv.gcf()) |
||||
|
|
||||
|
def turnCamera(self): |
||||
|
config = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)] |
||||
|
self.ax.SetLimits((-16,16),(-10,10),(-7,7)) |
||||
|
for elevation, azimuth in config: |
||||
|
self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth}) |
||||
|
|
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
@ -0,0 +1,68 @@ |
|||||
|
import sys |
||||
|
from enum import Flag, auto |
||||
|
|
||||
|
import numpy as np |
||||
|
|
||||
|
|
||||
|
class Verdict(Flag): |
||||
|
INCONCLUSIVE = auto() |
||||
|
PASS = auto() |
||||
|
FAIL = auto() |
||||
|
|
||||
|
|
||||
|
|
||||
|
class Simulator(): |
||||
|
def __init__(self, allStateActionPairs, strategy, deadlockStates, reachedStates, bound=3, numSimulations=1): |
||||
|
self.allStateActionPairs = { ( pair.state_id, pair.action_id ) : pair.next_state_probabilities for pair in allStateActionPairs } |
||||
|
self.strategy = strategy |
||||
|
self.deadlockStates = deadlockStates |
||||
|
self.reachedStates = reachedStates |
||||
|
|
||||
|
#print(f"Deadlock: {self.deadlockStates}") |
||||
|
#print(f"GoalStates: {self.reachedStates}") |
||||
|
|
||||
|
|
||||
|
self.bound = bound |
||||
|
self.numSimulations = numSimulations |
||||
|
|
||||
|
allStates = set([state.state_id for state in allStateActionPairs]) |
||||
|
allStates = allStates.difference(set(deadlockStates)) |
||||
|
allStates = allStates.difference(set(reachedStates)) |
||||
|
self.allStates = np.array(list(allStates)) |
||||
|
|
||||
|
def _pickRandomTestCase(self): |
||||
|
testCase = np.random.choice(self.allStates, 1)[0] |
||||
|
#self.allStates = np.delete(self.allStates, testCase) |
||||
|
return testCase |
||||
|
|
||||
|
def _simulate(self, initialStateId): |
||||
|
i = 0 |
||||
|
|
||||
|
actionId = self.strategy[initialStateId] |
||||
|
nextStatePair = (initialStateId, actionId) |
||||
|
|
||||
|
while i < self.bound: |
||||
|
i += 1 |
||||
|
nextStateProbabilities = self.allStateActionPairs[nextStatePair] |
||||
|
weights = list() |
||||
|
nextStateIds = list() |
||||
|
for nextStateId, probability in nextStateProbabilities.items(): |
||||
|
weights.append(probability) |
||||
|
nextStateIds.append(nextStateId) |
||||
|
nextStateId = np.random.choice(nextStateIds, 1, p=weights)[0] |
||||
|
if nextStateId in self.deadlockStates: |
||||
|
return Verdict.FAIL, i |
||||
|
if nextStateId in self.reachedStates: |
||||
|
return Verdict.PASS, i |
||||
|
nextStatePair = (nextStateId, self.strategy[nextStateId]) |
||||
|
return Verdict.INCONCLUSIVE, i |
||||
|
|
||||
|
def runTest(self): |
||||
|
testCase = self._pickRandomTestCase() |
||||
|
|
||||
|
histogram = [0,0,0] |
||||
|
for i in range(self.numSimulations): |
||||
|
result, numQueries = self._simulate(testCase) |
||||
|
if result == Verdict.FAIL: |
||||
|
return testCase, Verdict.FAIL, numQueries |
||||
|
return testCase, Verdict.INCONCLUSIVE, numQueries |
@ -0,0 +1,376 @@ |
|||||
|
#!/usr/bin/python3 |
||||
|
|
||||
|
import re, sys, os, shutil, fileinput, subprocess, argparse |
||||
|
from dataclasses import dataclass, field |
||||
|
import visvis as vv |
||||
|
import numpy as np |
||||
|
import argparse |
||||
|
|
||||
|
from translate import translateTransitions, readLabels |
||||
|
from plotting import VisVisPlotter |
||||
|
from simulation import Simulator, Verdict |
||||
|
|
||||
|
|
||||
|
def convert(tuples): |
||||
|
return dict(tuples) |
||||
|
|
||||
|
def getBasename(filename): |
||||
|
return os.path.basename(filename) |
||||
|
|
||||
|
def traFileWithIteration(filename, iteration): |
||||
|
return os.path.splitext(filename)[0] + f"_{iteration:03}.tra" |
||||
|
|
||||
|
def copyFile(filename, newFilename): |
||||
|
shutil.copy(filename, newFilename) |
||||
|
|
||||
|
|
||||
|
def execute(command, verbose=False): |
||||
|
if verbose: print(f"Executing {command}") |
||||
|
os.system(command) |
||||
|
|
||||
|
@dataclass(frozen=True) |
||||
|
class State: |
||||
|
id: int |
||||
|
x: float |
||||
|
x_vel: float |
||||
|
y: float |
||||
|
y_vel: float |
||||
|
z: float |
||||
|
z_vel: float |
||||
|
|
||||
|
def default_value(): |
||||
|
return {'action' : None, 'choiceValue' : None} |
||||
|
|
||||
|
|
||||
|
@dataclass(frozen=True) |
||||
|
class StateValue: |
||||
|
ranking: float |
||||
|
choices: dict = field(default_factory=default_value) |
||||
|
|
||||
|
@dataclass(frozen=False) |
||||
|
class TestResult: |
||||
|
prob_pes_min: float |
||||
|
prob_pes_max: float |
||||
|
prob_pes_avg: float |
||||
|
prob_opt_min: float |
||||
|
prob_opt_max: float |
||||
|
prob_opt_avg: float |
||||
|
min_min: float |
||||
|
min_max: float |
||||
|
|
||||
|
def csv(self, ws=" "): |
||||
|
return f"{self.prob_pes_min:0.04f}{ws}{self.prob_pes_max:0.04f}{ws}{self.prob_pes_avg:0.04f}{ws}{self.prob_opt_min:0.04f}{ws}{self.prob_opt_max:0.04f}{ws}{self.prob_opt_avg:0.04f}{ws}{self.min_min:0.04f}{ws}{self.min_max:0.04f}{ws}" |
||||
|
|
||||
|
def parseStrategy(strategyFile, allStateActionPairs, time_index=0): |
||||
|
strategy = dict() |
||||
|
with open(strategyFile) as strategyLines: |
||||
|
for line in strategyLines: |
||||
|
line = line.replace("(","").replace(")","").replace("\n", "") |
||||
|
explode = re.split(",|=", line) |
||||
|
stateId = int(explode[0]) + 3 |
||||
|
if stateId < 3: continue |
||||
|
if int(explode[1]) != time_index: continue |
||||
|
try: |
||||
|
strategy[stateId] = allStateActionPairs[stateId].index(explode[2]) |
||||
|
except KeyError as e: |
||||
|
pass |
||||
|
|
||||
|
return strategy |
||||
|
|
||||
|
def queryStrategy(strategy, stateId): |
||||
|
try: |
||||
|
return strategy[stateId] |
||||
|
except: |
||||
|
return -1 |
||||
|
|
||||
|
|
||||
|
def callTempest(files, reward, bound=3): |
||||
|
property_str = "!(\"failed\" | \"reached\")" |
||||
|
if True: |
||||
|
prop = f"filter(min, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" |
||||
|
prop += f"filter(max, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" |
||||
|
prop += f"filter(avg, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" |
||||
|
prop += f"filter(min, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" |
||||
|
prop += f"filter(max, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" |
||||
|
prop += f"filter(avg, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" |
||||
|
prop += f"filter(min, Rmin=? [ C<={bound} ], {property_str} );" |
||||
|
prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str} );" |
||||
|
else: |
||||
|
prop = f"filter(min, Rmin=? [ C<={bound} ], {property_str} );" |
||||
|
prop += f"filter(max, Rmin=? [ C<={bound} ], {property_str} );" |
||||
|
prop += f"filter(avg, Rmin=? [ C<={bound} ], {property_str} );" |
||||
|
prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str} );" |
||||
|
prop += f"filter(max, Rmax=? [ C<={bound} ], {property_str} );" |
||||
|
prop += f"filter(avg, Rmax=? [ C<={bound} ], {property_str} );" |
||||
|
command = f"~/projects/tempest-devel/ranking_release/bin/storm --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' " |
||||
|
|
||||
|
results = list() |
||||
|
try: |
||||
|
output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') |
||||
|
for line in output: |
||||
|
if "Result" in line and not len(results) >= 10: |
||||
|
range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) |
||||
|
if range_value: |
||||
|
results.append(float(range_value.group(2))) |
||||
|
results.append(float(range_value.group(3))) |
||||
|
else: |
||||
|
value = re.search(r"(.*:)(.*)", line) |
||||
|
results.append(float(value.group(2))) |
||||
|
except subprocess.CalledProcessError as e: |
||||
|
print(e.output) |
||||
|
#results.append(-1) |
||||
|
#results.append(-1) |
||||
|
return TestResult(*(tuple(results))) |
||||
|
|
||||
|
def parseRanking(filename, allStates): |
||||
|
state_ranking = dict() |
||||
|
try: |
||||
|
with open(filename, "r") as f: |
||||
|
filecontent = f.readlines() |
||||
|
for line in filecontent: |
||||
|
stateId = int(re.findall(r"^\d+", line)[0]) |
||||
|
values = re.findall(r":(-?\d+\.?\d*),?", line) |
||||
|
ranking_value = float(values[0]) |
||||
|
choices = {i : float(value) for i,value in enumerate(values[1:])} |
||||
|
state = allStates[stateId] |
||||
|
value = StateValue(ranking_value, choices) |
||||
|
state_ranking[state] = value |
||||
|
if len(state_ranking) == 0: return |
||||
|
all_values = [x.ranking for x in state_ranking.values()] |
||||
|
max_value = max(all_values) |
||||
|
min_value = min(all_values) |
||||
|
new_state_ranking = {} |
||||
|
for state, value in state_ranking.items(): |
||||
|
choices = value.choices |
||||
|
try: |
||||
|
new_value = (value.ranking - min_value) / (max_value - min_value) |
||||
|
except ZeroDivisionError as e: |
||||
|
new_value = 0.0 |
||||
|
new_state_ranking[state] = StateValue(new_value, choices) |
||||
|
state_ranking = new_state_ranking |
||||
|
except EnvironmentError: |
||||
|
print("TODO file not available. Exiting.") |
||||
|
sys.exit(1) |
||||
|
return {state: values for state, values in sorted(state_ranking.items(), key=lambda item: item[1].ranking)} |
||||
|
|
||||
|
def parseStateValuations(filename): |
||||
|
all_states = dict() |
||||
|
maxStateId = -1 |
||||
|
for i in [0,1,2]: |
||||
|
dummy_values = [i] * 7 |
||||
|
all_states[i] = State(*dummy_values) |
||||
|
with open(filename) as stateValuations: |
||||
|
for line in stateValuations: |
||||
|
values = re.findall(r"(-?\d+\.?\d*),?", line) |
||||
|
values = [int(values[0])] + [float(v) for v in values[1:]] |
||||
|
all_states[values[0]] = State(*values) |
||||
|
if values[0] > maxStateId: maxStateId = values[0] |
||||
|
dummy_values = [maxStateId + 1] * 7 |
||||
|
all_states[maxStateId + 1] = State(*dummy_values) |
||||
|
return all_states |
||||
|
|
||||
|
def parseResults(allStates): |
||||
|
state_to_values = dict() |
||||
|
with open("prob_results_maximize") as maximizer, open("prob_results_minimize") as minimizer: |
||||
|
for max_line, min_line in zip(maximizer, minimizer): |
||||
|
max_values = re.findall(r"(-?\d+\.?\d*),?", max_line) |
||||
|
min_values = re.findall(r"(-?\d+\.?\d*),?", min_line) |
||||
|
if max_values[0] != min_values[0]: |
||||
|
print("min/max files do not match.") |
||||
|
assert(False) |
||||
|
stateId = int(max_values[0]) |
||||
|
min_result = float(min_values[1]) |
||||
|
max_result = float(max_values[1]) |
||||
|
value = (min_result, max_result, max_result - min_result) |
||||
|
state_to_values[stateId] = value |
||||
|
return state_to_values |
||||
|
|
||||
|
|
||||
|
def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration): |
||||
|
stateIdRegex = re.compile(f"^{stateId}\s") |
||||
|
for line in fileinput.input(filename, inplace = True): |
||||
|
if not stateIdRegex.match(line): |
||||
|
print(line, end="") |
||||
|
else: |
||||
|
explode = line.split(" ") |
||||
|
if int(explode[1]) == chosenActionIndex: |
||||
|
print(line, end="") |
||||
|
|
||||
|
def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration): |
||||
|
stateIdsRegex = re.compile("|".join([f"^{stateId}\s" for stateId, actionIndex in stateActionPairsToTrim.items()])) |
||||
|
for line in fileinput.input(filename, inplace = True): |
||||
|
result = stateIdsRegex.match(line) |
||||
|
if not result: |
||||
|
print(line, end="") |
||||
|
else: |
||||
|
actionIndex = stateActionPairsToTrim[int(result[0])] |
||||
|
explode = line.split(" ") |
||||
|
if int(explode[1]) == actionIndex: |
||||
|
print(line, end="") |
||||
|
|
||||
|
def getTopNStates(rankedStates, n, threshold): |
||||
|
if n != 0: |
||||
|
return dict(list(rankedStates.items())[-n:]) |
||||
|
else: |
||||
|
return {state:value for state,value in rankedStates.items() if value.ranking >= threshold} |
||||
|
|
||||
|
def getNRandomStates(rankedStates, n, testedStates): |
||||
|
stateIds = [state.id for state in rankedStates.keys()] |
||||
|
notYetTestedStates = np.array([stateId for stateId in stateIds if stateId not in testedStates]) |
||||
|
if len(notYetTestedStates) >= n: |
||||
|
return notYetTestedStates[np.random.choice(len(notYetTestedStates), size=n, replace=False)] |
||||
|
else: |
||||
|
return notYetTestedStates |
||||
|
|
||||
|
def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False): |
||||
|
|
||||
|
all_states = parseStateValuations("MDP_state_valuations") |
||||
|
deadlockStates, reachedStates, maxStateId = readLabels(labFile) |
||||
|
stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) |
||||
|
strategy = parseStrategy(straFile, stateToActions) |
||||
|
|
||||
|
if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting) |
||||
|
if plotting: plotter.plotScenario() |
||||
|
|
||||
|
copyFile("MDP_" + traFile, "MDP_" + os.path.splitext(getBasename(traFile))[0] + f"_000.tra") |
||||
|
|
||||
|
iteration = 0 |
||||
|
#testsPerIteration = refinementSteps |
||||
|
#refinementThreshold = |
||||
|
numTestedStates = 0 |
||||
|
totalIterations = 60 |
||||
|
|
||||
|
testedStates = list() |
||||
|
while iteration < totalIterations: |
||||
|
print(f"{iteration:03}", end="\t") |
||||
|
sys.stdout.flush() |
||||
|
currentTraFile = traFileWithIteration("MDP_" + traFile, iteration) |
||||
|
nextTraFile = traFileWithIteration("MDP_" + traFile, iteration+1) |
||||
|
testResult = callTempest(f"{currentTraFile} MDP_{labFile}", "saferew", horizonBound) |
||||
|
state_ranking = parseRanking("action_ranking", all_states) |
||||
|
copyFile("action_ranking", f"action_ranking_{iteration:03}") |
||||
|
copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}") |
||||
|
copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}") |
||||
|
|
||||
|
if not ablationTesting: |
||||
|
importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound) |
||||
|
statesToTest = [state.id for state in importantStates.keys()] |
||||
|
statesToPlot = importantStates |
||||
|
else: |
||||
|
statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates)) |
||||
|
testedStates += statesToTest |
||||
|
statesToPlot = {all_states[stateId]:StateValue(0,{}) for stateId in statesToTest} |
||||
|
|
||||
|
|
||||
|
copyFile(currentTraFile, nextTraFile) |
||||
|
stateActionPairsToTrim = dict() |
||||
|
for testState in statesToTest: |
||||
|
chosenActionIndex = queryStrategy(strategy, testState) |
||||
|
if chosenActionIndex != -1: |
||||
|
stateActionPairsToTrim[testState] = chosenActionIndex |
||||
|
stateEstimates = parseResults(all_states) |
||||
|
results = [0,0,0] |
||||
|
|
||||
|
failureStates = list() |
||||
|
validatedStates = list() |
||||
|
for state, estimates in stateEstimates.items(): |
||||
|
if state in deadlockStates or state in reachedStates: |
||||
|
continue |
||||
|
if estimates[0] > 0.05: |
||||
|
results[1] += 1 |
||||
|
failureStates.append(all_states[state]) |
||||
|
#print(f"{state}: {estimates}") |
||||
|
elif estimates[1] <= 0.05: |
||||
|
results[0] += 1 |
||||
|
validatedStates.append(all_states[state]) |
||||
|
else: |
||||
|
results[2] += 1 |
||||
|
|
||||
|
|
||||
|
removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration) |
||||
|
print(f"{numTestedStates}\t{testResult.csv(' ')}\t{results[0]}\t{results[1]}\t{results[2]}\t{sum(results)}") |
||||
|
if results[2] == 0: |
||||
|
sys.exit(0) |
||||
|
numTestedStates += len(statesToTest) |
||||
|
iteration += 1 |
||||
|
|
||||
|
if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6), removeMeshes=True) |
||||
|
if plotting: plotter.plotStates(validatedStates, coloring=(0.0,0.8,0.0,0.6)) |
||||
|
if plotting: plotter.takeScreenshot(iteration, prefix="stepwise_0.05") |
||||
|
|
||||
|
def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False): |
||||
|
all_states = parseStateValuations("MDP_state_valuations") |
||||
|
deadlockStates, reachedStates, maxStateId = readLabels(labFile) |
||||
|
stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) |
||||
|
strategy = parseStrategy(straFile, stateToActions) |
||||
|
|
||||
|
if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting) |
||||
|
if plotting: plotter.plotScenario() |
||||
|
passingStates = list() |
||||
|
|
||||
|
randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound) |
||||
|
i = 0 |
||||
|
print("Starting with random testing.") |
||||
|
numQueries = 0 |
||||
|
failureStates = list() |
||||
|
while numQueries <= maxQueries: |
||||
|
if i >= 500: |
||||
|
if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6)) |
||||
|
if plotting: plotter.takeScreenshot(iteration, prefix="random_testing") |
||||
|
if plotting: plotter.turnCamera() |
||||
|
if plotting: input("") |
||||
|
print(f"{numQueries} {len(failureStates)} ") |
||||
|
i = 0 |
||||
|
testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest() |
||||
|
i += queriesForThisTestCase |
||||
|
numQueries += queriesForThisTestCase |
||||
|
stateValuation = all_states[testCase] |
||||
|
if testResult == Verdict.FAIL: |
||||
|
failureStates.append(stateValuation) |
||||
|
|
||||
|
print(f"{numQueries} {len(failureStates)} ") |
||||
|
|
||||
|
def parseArgs(): |
||||
|
parser = argparse.ArgumentParser() |
||||
|
parser.add_argument('--tra', type=str, required=True, help='Path to .tra file.') |
||||
|
parser.add_argument('--lab', type=str, required=True, help='Path to .lab file.') |
||||
|
parser.add_argument('--rew', type=str, required=True, help='Path to .rew file.') |
||||
|
parser.add_argument('--stra', type=str, required=True, help='Path to strategy file.') |
||||
|
|
||||
|
refinement = parser.add_mutually_exclusive_group(required=True) |
||||
|
refinement.add_argument('--refinement-steps', type=int, default=0, help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.') |
||||
|
refinement.add_argument('--refinement-bound', type=float, default=0, help='Threshold value for states to be tested, mutually exclusive with refinement-steps.') |
||||
|
|
||||
|
parser.add_argument('--bound', type=int, required=False, default=3, help='(optional) Safety Horizon Bound, defaults to 3.') |
||||
|
parser.add_argument('--threshold', type=float, required=False, default=0.05, help='(optional) Safety Threshold, defaults to 0.05.') |
||||
|
|
||||
|
random_testing = parser.add_mutually_exclusive_group() |
||||
|
random_testing.add_argument('-a', '--ablation', action='store_true', help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.") |
||||
|
random_testing.add_argument('-r', '--random', type=int, default=0, help='(optional) The amount of queries allowed for random testing.') |
||||
|
|
||||
|
parser.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.') |
||||
|
parser.add_argument('--stepwise', action='store_true', help='(optional) Remove states before plotting the next iteration.') |
||||
|
return parser.parse_args() |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
args = parseArgs() |
||||
|
|
||||
|
traFile = args.tra |
||||
|
labFile = args.lab |
||||
|
straFile = args.stra |
||||
|
rewFile = args.rew |
||||
|
|
||||
|
ablationTesting = args.ablation |
||||
|
plotting = args.plotting |
||||
|
stepwisePlotting = args.stepwise |
||||
|
|
||||
|
maxQueriesForRandomTesting = args.random |
||||
|
horizonBound = args.bound |
||||
|
|
||||
|
refinementSteps = args.refinement_steps |
||||
|
refinementBound = args.refinement_bound |
||||
|
|
||||
|
if maxQueriesForRandomTesting == 0: #akward way to test for this... |
||||
|
main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting) |
||||
|
else: |
||||
|
randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting) |
@ -0,0 +1,135 @@ |
|||||
|
#!/usr/bin/python3 |
||||
|
|
||||
|
import sys |
||||
|
#import re |
||||
|
import json |
||||
|
import numpy as np |
||||
|
from collections import deque |
||||
|
|
||||
|
from dataclasses import dataclass |
||||
|
|
||||
|
|
||||
|
@dataclass(frozen=False) |
||||
|
class StateAction: |
||||
|
state_id: int |
||||
|
action_id: int |
||||
|
next_state_probabilities: dict |
||||
|
action_name: str |
||||
|
|
||||
|
def normalizeDistribution(self): |
||||
|
weight = np.sum([value for key, value in self.next_state_probabilities.items()]) |
||||
|
self.next_state_probabilities = { next_state_id : (probability / weight) for next_state_id, probability in self.next_state_probabilities.items() } |
||||
|
|
||||
|
|
||||
|
def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId): |
||||
|
current_state_id = -1 |
||||
|
current_action_id = -1 |
||||
|
all_state_action_pairs = list() |
||||
|
with open(traFile) as transitions: |
||||
|
next(transitions) |
||||
|
for line in transitions: |
||||
|
line = line.replace("\n","") |
||||
|
explode = line.split(" ") |
||||
|
if len(explode) < 2: continue |
||||
|
interval = json.loads(explode[3]) |
||||
|
probability = (interval[0] + interval[1])/2 |
||||
|
state_id = int(explode[0]) |
||||
|
action_id = int(explode[1]) |
||||
|
next_state_id = int(explode[2]) |
||||
|
if len(explode) >= 5: |
||||
|
action_name = explode[4] |
||||
|
else: |
||||
|
action_name = "" |
||||
|
#print(f"State : {state_id} with action {action_id} leads with {probability} to {next_state_id}.") |
||||
|
if state_id in [0, 1, 2]: |
||||
|
continue |
||||
|
|
||||
|
next_state_probabilities = {next_state_id: probability} |
||||
|
if current_state_id != state_id: |
||||
|
new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name) |
||||
|
all_state_action_pairs.append(new_state_action_pair) |
||||
|
current_state_id = state_id |
||||
|
current_action_id = action_id |
||||
|
|
||||
|
elif current_action_id != action_id: |
||||
|
new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name) |
||||
|
all_state_action_pairs.append(new_state_action_pair) |
||||
|
current_action_id = action_id |
||||
|
else: |
||||
|
all_state_action_pairs[-1].next_state_probabilities[next_state_id] = probability |
||||
|
|
||||
|
# we need to sort the deadlock and reached states to insert them while building the .tra file |
||||
|
deadlockStates = [(state, 0) for state in deadlockStates] |
||||
|
reachedStates = [(state, maxStateId) for state in reachedStates] |
||||
|
final_states = deadlockStates + reachedStates |
||||
|
final_states = deque(sorted(final_states, key=lambda tuple: tuple[0], reverse=True)) |
||||
|
|
||||
|
with open("MDP_" + traFile, "w") as new_transitions_file: |
||||
|
new_transitions_file.write("mdp\n") |
||||
|
new_transitions_file.write(f"0 0 {maxStateId} 1.0\n") |
||||
|
for entry in all_state_action_pairs: |
||||
|
entry.normalizeDistribution() |
||||
|
source_state = int(entry.state_id) |
||||
|
while final_states and int(final_states[-1][0]) < source_state: |
||||
|
final_state = final_states.pop() |
||||
|
if int(final_state[0]) == 0: continue |
||||
|
new_transitions_file.write(f"{final_state[0]} 0 {final_state[1]} 1.0\n") |
||||
|
for next_state_id, probability in entry.next_state_probabilities.items(): |
||||
|
new_transitions_file.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n") |
||||
|
new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n") |
||||
|
|
||||
|
state_to_actions = dict() |
||||
|
for state_action_pair in all_state_action_pairs: |
||||
|
if state_action_pair.state_id in state_to_actions: |
||||
|
state_to_actions[state_action_pair.state_id].append(state_action_pair.action_name) |
||||
|
else: |
||||
|
state_to_actions[state_action_pair.state_id] = [state_action_pair.action_name] |
||||
|
return state_to_actions, all_state_action_pairs |
||||
|
|
||||
|
def readLabels(labFile): |
||||
|
deadlockStates = list() |
||||
|
reachedStates = list() |
||||
|
with open(labFile) as states: |
||||
|
newLabFile = "MDP_" + labFile |
||||
|
newLabels = open(newLabFile, "w") |
||||
|
optRewards = open(newLabFile + ".optrew", "w") |
||||
|
safetyRewards = open(newLabFile + ".saferew", "w") |
||||
|
labels = ["init", "deadlock", "reached", "failed"] |
||||
|
next(states) |
||||
|
newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n") |
||||
|
maxStateId = -1 |
||||
|
for line in states: |
||||
|
line = line.replace(":","").replace("\n", "") |
||||
|
explode = line.split(" ") |
||||
|
newLabel = f"{explode[0]} " |
||||
|
if int(explode[0]) > maxStateId: maxStateId = int(explode[0]) |
||||
|
if int(explode[0]) == 0: |
||||
|
safetyRewards.write(f"{explode[0]} -100\n") |
||||
|
optRewards.write(f"{explode[0]} -100\n") |
||||
|
#if "3" in explode: |
||||
|
# safetyRewards.write(f"{explode[0]} -100\n") |
||||
|
# optRewards.write(f"{explode[0]} -100\n") |
||||
|
elif "2" in explode: |
||||
|
optRewards.write(f"{explode[0]} 100\n") |
||||
|
else: |
||||
|
optRewards.write(f"{explode[0]} -1\n") |
||||
|
for labelIndex in explode[1:]: |
||||
|
# sink states should not be deadlock states anymore: |
||||
|
if labelIndex == "1": |
||||
|
deadlockStates.append(int(explode[0])) |
||||
|
continue |
||||
|
if labelIndex == "2": |
||||
|
reachedStates.append(int(explode[0])) |
||||
|
continue |
||||
|
newLabel += f"{labels[int(labelIndex)]} " |
||||
|
newLabels.write(newLabel + "\n") |
||||
|
return deadlockStates, reachedStates, maxStateId + 1 |
||||
|
|
||||
|
def main(traFile, labFile): |
||||
|
deadlockStates, reachedStates, maxStateId = readLabels(labFile) |
||||
|
translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
traFile = sys.argv[1] |
||||
|
labFile = sys.argv[2] |
||||
|
main(traFile, labFile) |
Write
Preview
Loading…
Cancel
Save
Reference in new issue