commit b3be734b8facc2f67eaf1b2bbe571c2d768fa89b Author: sp Date: Tue Dec 26 10:59:37 2023 +0100 finished testing setup diff --git a/plotting.py b/plotting.py new file mode 100755 index 0000000..efb9380 --- /dev/null +++ b/plotting.py @@ -0,0 +1,88 @@ +#!/usr/bin/python3 + +import visvis as vv +import numpy as np + +import time + +def translateArrayToVV(array): + scaling = (float(array[0][1]) - float(array[0][0]), float(array[2][1]) - float(array[2][0]), float(array[4][1]) - float(array[4][0])) + translation = (float(array[0][0]) + scaling[0] * 0.5, float(array[2][0]) + scaling[0] * 0.5, float(array[4][0]) + scaling[0] * 0.5) + return translation, scaling + +class VisVisPlotter(): + def __init__(self, stateValuations, goalStates, deadlockStates, stepwise): + self.app = vv.use() + self.fig = vv.clf() + self.ax = vv.cla() + self.stateColor = (0,0,0.75,0.4) + self.failColor = (0.8,0.8,0.8,1.0) # (1,0,0,0.8) + self.goalColor = (0,1,0,1.0) + self.stateScaling = (2,2,2) #(.5,.5,.5) + + self.plotedStates = set() + self.plotedMeshes = set() + self.stepwise = stepwise + + auxStates = [0,1,2,25518] + self.goals = set([(stateValuation.x, stateValuation.y, stateValuation.z) for stateId, stateValuation in stateValuations.items() if stateId in goalStates and not stateId in auxStates]) + self.fails = set([(stateValuation.x, stateValuation.y, stateValuation.z) for stateId, stateValuation in stateValuations.items() if stateId in deadlockStates and not stateId in auxStates]) + + def plotScenario(self, saveScreenshot=True): + for goal in self.goals: + state = vv.solidBox(goal, scaling=self.stateScaling) + state.faceColor = self.goalColor + for fail in self.fails: + state = vv.solidBox(fail, scaling=self.stateScaling) + state.faceColor = self.failColor + + self.ax.SetLimits((-16,16),(-10,10),(-7,7)) + self.ax.SetView({'zoom':0.025, 'elevation':20, 'azimuth':30}) + if saveScreenshot: vv.screenshot("000.png", sf=3, bg='w', ob=vv.gcf()) + + def run(self): + self.ax.SetLimits((-16,16),(-10,10),(-7,7)) + self.app.Run() + + def clear(self): + axes = vv.gca() + axes.Clear() + + def plotStates(self, states, coloring="", removeMeshes=False): + if self.stepwise and removeMeshes: + self.clear() + self.plotScenario(saveScreenshot=False) + if not coloring: + coloring = self.stateColor + plotedRegions = set() + for state in states: + if state in self.plotedStates: continue # what are the implications of this? + coordinates = (state.x, state.y, state.z) + print(f"plotting {state} at {coordinates}") + if coordinates in plotedRegions: continue + state = vv.solidBox(coordinates, scaling=(1,1,1))#(0.5,0.5,0.5)) + state.faceColor = coloring + plotedRegions.add(coordinates) + if self.stepwise: + self.plotedMeshes.add(state) + self.plotedStates.update(states) + + def takeScreenshot(self, iteration, prefix=""): + self.ax.SetLimits((-16,16),(-10,10),(-7,7)) + #config = [(90, 00), (45, 00), (0, 00), (-45, 00), (-90, 00)] + config = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)] + for elevation, azimuth in config: + filename = f"{prefix}{'_' if prefix else ''}{iteration:03}_{elevation}_{azimuth}.png" + print(f"Saving Screenshot to {filename}") + self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth}) + vv.screenshot(filename, sf=3, bg='w', ob=vv.gcf()) + + def turnCamera(self): + config = [(45, -150), (45, -100), (45, -60), (45, 60), (45, 120)] + self.ax.SetLimits((-16,16),(-10,10),(-7,7)) + for elevation, azimuth in config: + self.ax.SetView({'zoom':0.025, 'elevation':elevation, 'azimuth':azimuth}) + + +if __name__ == '__main__': + main() diff --git a/simulation.py b/simulation.py new file mode 100644 index 0000000..4dcd0dd --- /dev/null +++ b/simulation.py @@ -0,0 +1,68 @@ +import sys +from enum import Flag, auto + +import numpy as np + + +class Verdict(Flag): + INCONCLUSIVE = auto() + PASS = auto() + FAIL = auto() + + + +class Simulator(): + def __init__(self, allStateActionPairs, strategy, deadlockStates, reachedStates, bound=3, numSimulations=1): + self.allStateActionPairs = { ( pair.state_id, pair.action_id ) : pair.next_state_probabilities for pair in allStateActionPairs } + self.strategy = strategy + self.deadlockStates = deadlockStates + self.reachedStates = reachedStates + + #print(f"Deadlock: {self.deadlockStates}") + #print(f"GoalStates: {self.reachedStates}") + + + self.bound = bound + self.numSimulations = numSimulations + + allStates = set([state.state_id for state in allStateActionPairs]) + allStates = allStates.difference(set(deadlockStates)) + allStates = allStates.difference(set(reachedStates)) + self.allStates = np.array(list(allStates)) + + def _pickRandomTestCase(self): + testCase = np.random.choice(self.allStates, 1)[0] + #self.allStates = np.delete(self.allStates, testCase) + return testCase + + def _simulate(self, initialStateId): + i = 0 + + actionId = self.strategy[initialStateId] + nextStatePair = (initialStateId, actionId) + + while i < self.bound: + i += 1 + nextStateProbabilities = self.allStateActionPairs[nextStatePair] + weights = list() + nextStateIds = list() + for nextStateId, probability in nextStateProbabilities.items(): + weights.append(probability) + nextStateIds.append(nextStateId) + nextStateId = np.random.choice(nextStateIds, 1, p=weights)[0] + if nextStateId in self.deadlockStates: + return Verdict.FAIL, i + if nextStateId in self.reachedStates: + return Verdict.PASS, i + nextStatePair = (nextStateId, self.strategy[nextStateId]) + return Verdict.INCONCLUSIVE, i + + def runTest(self): + testCase = self._pickRandomTestCase() + + histogram = [0,0,0] + for i in range(self.numSimulations): + result, numQueries = self._simulate(testCase) + if result == Verdict.FAIL: + return testCase, Verdict.FAIL, numQueries + return testCase, Verdict.INCONCLUSIVE, numQueries diff --git a/test_model.py b/test_model.py new file mode 100755 index 0000000..4cf07bc --- /dev/null +++ b/test_model.py @@ -0,0 +1,376 @@ +#!/usr/bin/python3 + +import re, sys, os, shutil, fileinput, subprocess, argparse +from dataclasses import dataclass, field +import visvis as vv +import numpy as np +import argparse + +from translate import translateTransitions, readLabels +from plotting import VisVisPlotter +from simulation import Simulator, Verdict + + +def convert(tuples): + return dict(tuples) + +def getBasename(filename): + return os.path.basename(filename) + +def traFileWithIteration(filename, iteration): + return os.path.splitext(filename)[0] + f"_{iteration:03}.tra" + +def copyFile(filename, newFilename): + shutil.copy(filename, newFilename) + + +def execute(command, verbose=False): + if verbose: print(f"Executing {command}") + os.system(command) + +@dataclass(frozen=True) +class State: + id: int + x: float + x_vel: float + y: float + y_vel: float + z: float + z_vel: float + +def default_value(): + return {'action' : None, 'choiceValue' : None} + + +@dataclass(frozen=True) +class StateValue: + ranking: float + choices: dict = field(default_factory=default_value) + +@dataclass(frozen=False) +class TestResult: + prob_pes_min: float + prob_pes_max: float + prob_pes_avg: float + prob_opt_min: float + prob_opt_max: float + prob_opt_avg: float + min_min: float + min_max: float + + def csv(self, ws=" "): + return f"{self.prob_pes_min:0.04f}{ws}{self.prob_pes_max:0.04f}{ws}{self.prob_pes_avg:0.04f}{ws}{self.prob_opt_min:0.04f}{ws}{self.prob_opt_max:0.04f}{ws}{self.prob_opt_avg:0.04f}{ws}{self.min_min:0.04f}{ws}{self.min_max:0.04f}{ws}" + +def parseStrategy(strategyFile, allStateActionPairs, time_index=0): + strategy = dict() + with open(strategyFile) as strategyLines: + for line in strategyLines: + line = line.replace("(","").replace(")","").replace("\n", "") + explode = re.split(",|=", line) + stateId = int(explode[0]) + 3 + if stateId < 3: continue + if int(explode[1]) != time_index: continue + try: + strategy[stateId] = allStateActionPairs[stateId].index(explode[2]) + except KeyError as e: + pass + + return strategy + +def queryStrategy(strategy, stateId): + try: + return strategy[stateId] + except: + return -1 + + +def callTempest(files, reward, bound=3): + property_str = "!(\"failed\" | \"reached\")" + if True: + prop = f"filter(min, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" + prop += f"filter(max, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" + prop += f"filter(avg, Pmax=? [ true U<={bound} \"failed\" ], {property_str} );" + prop += f"filter(min, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" + prop += f"filter(max, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" + prop += f"filter(avg, Pmin=? [ true U<={bound} \"failed\" ], {property_str} );" + prop += f"filter(min, Rmin=? [ C<={bound} ], {property_str} );" + prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str} );" + else: + prop = f"filter(min, Rmin=? [ C<={bound} ], {property_str} );" + prop += f"filter(max, Rmin=? [ C<={bound} ], {property_str} );" + prop += f"filter(avg, Rmin=? [ C<={bound} ], {property_str} );" + prop += f"filter(min, Rmax=? [ C<={bound} ], {property_str} );" + prop += f"filter(max, Rmax=? [ C<={bound} ], {property_str} );" + prop += f"filter(avg, Rmax=? [ C<={bound} ], {property_str} );" + command = f"~/projects/tempest-devel/ranking_release/bin/storm --io:explicit {files} --io:staterew MDP_Abstraction_interval.lab.{reward} --prop '{prop}' " + + results = list() + try: + output = subprocess.check_output(command, shell=True).decode("utf-8").split('\n') + for line in output: + if "Result" in line and not len(results) >= 10: + range_value = re.search(r"(.*:).*\[(-?\d+\.?\d*), (-?\d+\.?\d*)\].*", line) + if range_value: + results.append(float(range_value.group(2))) + results.append(float(range_value.group(3))) + else: + value = re.search(r"(.*:)(.*)", line) + results.append(float(value.group(2))) + except subprocess.CalledProcessError as e: + print(e.output) + #results.append(-1) + #results.append(-1) + return TestResult(*(tuple(results))) + +def parseRanking(filename, allStates): + state_ranking = dict() + try: + with open(filename, "r") as f: + filecontent = f.readlines() + for line in filecontent: + stateId = int(re.findall(r"^\d+", line)[0]) + values = re.findall(r":(-?\d+\.?\d*),?", line) + ranking_value = float(values[0]) + choices = {i : float(value) for i,value in enumerate(values[1:])} + state = allStates[stateId] + value = StateValue(ranking_value, choices) + state_ranking[state] = value + if len(state_ranking) == 0: return + all_values = [x.ranking for x in state_ranking.values()] + max_value = max(all_values) + min_value = min(all_values) + new_state_ranking = {} + for state, value in state_ranking.items(): + choices = value.choices + try: + new_value = (value.ranking - min_value) / (max_value - min_value) + except ZeroDivisionError as e: + new_value = 0.0 + new_state_ranking[state] = StateValue(new_value, choices) + state_ranking = new_state_ranking + except EnvironmentError: + print("TODO file not available. Exiting.") + sys.exit(1) + return {state: values for state, values in sorted(state_ranking.items(), key=lambda item: item[1].ranking)} + +def parseStateValuations(filename): + all_states = dict() + maxStateId = -1 + for i in [0,1,2]: + dummy_values = [i] * 7 + all_states[i] = State(*dummy_values) + with open(filename) as stateValuations: + for line in stateValuations: + values = re.findall(r"(-?\d+\.?\d*),?", line) + values = [int(values[0])] + [float(v) for v in values[1:]] + all_states[values[0]] = State(*values) + if values[0] > maxStateId: maxStateId = values[0] + dummy_values = [maxStateId + 1] * 7 + all_states[maxStateId + 1] = State(*dummy_values) + return all_states + +def parseResults(allStates): + state_to_values = dict() + with open("prob_results_maximize") as maximizer, open("prob_results_minimize") as minimizer: + for max_line, min_line in zip(maximizer, minimizer): + max_values = re.findall(r"(-?\d+\.?\d*),?", max_line) + min_values = re.findall(r"(-?\d+\.?\d*),?", min_line) + if max_values[0] != min_values[0]: + print("min/max files do not match.") + assert(False) + stateId = int(max_values[0]) + min_result = float(min_values[1]) + max_result = float(max_values[1]) + value = (min_result, max_result, max_result - min_result) + state_to_values[stateId] = value + return state_to_values + + +def removeActionFromTransitionFile(stateId, chosenActionIndex, filename, iteration): + stateIdRegex = re.compile(f"^{stateId}\s") + for line in fileinput.input(filename, inplace = True): + if not stateIdRegex.match(line): + print(line, end="") + else: + explode = line.split(" ") + if int(explode[1]) == chosenActionIndex: + print(line, end="") + +def removeActionsFromTransitionFile(stateActionPairsToTrim, filename, iteration): + stateIdsRegex = re.compile("|".join([f"^{stateId}\s" for stateId, actionIndex in stateActionPairsToTrim.items()])) + for line in fileinput.input(filename, inplace = True): + result = stateIdsRegex.match(line) + if not result: + print(line, end="") + else: + actionIndex = stateActionPairsToTrim[int(result[0])] + explode = line.split(" ") + if int(explode[1]) == actionIndex: + print(line, end="") + +def getTopNStates(rankedStates, n, threshold): + if n != 0: + return dict(list(rankedStates.items())[-n:]) + else: + return {state:value for state,value in rankedStates.items() if value.ranking >= threshold} + +def getNRandomStates(rankedStates, n, testedStates): + stateIds = [state.id for state in rankedStates.keys()] + notYetTestedStates = np.array([stateId for stateId in stateIds if stateId not in testedStates]) + if len(notYetTestedStates) >= n: + return notYetTestedStates[np.random.choice(len(notYetTestedStates), size=n, replace=False)] + else: + return notYetTestedStates + +def main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting=False, stepwisePlotting=False): + + all_states = parseStateValuations("MDP_state_valuations") + deadlockStates, reachedStates, maxStateId = readLabels(labFile) + stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) + strategy = parseStrategy(straFile, stateToActions) + + if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting) + if plotting: plotter.plotScenario() + + copyFile("MDP_" + traFile, "MDP_" + os.path.splitext(getBasename(traFile))[0] + f"_000.tra") + + iteration = 0 + #testsPerIteration = refinementSteps + #refinementThreshold = + numTestedStates = 0 + totalIterations = 60 + + testedStates = list() + while iteration < totalIterations: + print(f"{iteration:03}", end="\t") + sys.stdout.flush() + currentTraFile = traFileWithIteration("MDP_" + traFile, iteration) + nextTraFile = traFileWithIteration("MDP_" + traFile, iteration+1) + testResult = callTempest(f"{currentTraFile} MDP_{labFile}", "saferew", horizonBound) + state_ranking = parseRanking("action_ranking", all_states) + copyFile("action_ranking", f"action_ranking_{iteration:03}") + copyFile("prob_results_maximize", f"prob_results_maximize_{iteration:03}") + copyFile("prob_results_minimize", f"prob_results_minimize_{iteration:03}") + + if not ablationTesting: + importantStates = getTopNStates(state_ranking, refinementSteps, refinementBound) + statesToTest = [state.id for state in importantStates.keys()] + statesToPlot = importantStates + else: + statesToTest = list(getNRandomStates(state_ranking, refinementSteps, testedStates)) + testedStates += statesToTest + statesToPlot = {all_states[stateId]:StateValue(0,{}) for stateId in statesToTest} + + + copyFile(currentTraFile, nextTraFile) + stateActionPairsToTrim = dict() + for testState in statesToTest: + chosenActionIndex = queryStrategy(strategy, testState) + if chosenActionIndex != -1: + stateActionPairsToTrim[testState] = chosenActionIndex + stateEstimates = parseResults(all_states) + results = [0,0,0] + + failureStates = list() + validatedStates = list() + for state, estimates in stateEstimates.items(): + if state in deadlockStates or state in reachedStates: + continue + if estimates[0] > 0.05: + results[1] += 1 + failureStates.append(all_states[state]) + #print(f"{state}: {estimates}") + elif estimates[1] <= 0.05: + results[0] += 1 + validatedStates.append(all_states[state]) + else: + results[2] += 1 + + + removeActionsFromTransitionFile(stateActionPairsToTrim, nextTraFile, iteration) + print(f"{numTestedStates}\t{testResult.csv(' ')}\t{results[0]}\t{results[1]}\t{results[2]}\t{sum(results)}") + if results[2] == 0: + sys.exit(0) + numTestedStates += len(statesToTest) + iteration += 1 + + if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6), removeMeshes=True) + if plotting: plotter.plotStates(validatedStates, coloring=(0.0,0.8,0.0,0.6)) + if plotting: plotter.takeScreenshot(iteration, prefix="stepwise_0.05") + +def randomTesting(traFile, labFile, straFile, bound, maxQueries, plotting=False): + all_states = parseStateValuations("MDP_state_valuations") + deadlockStates, reachedStates, maxStateId = readLabels(labFile) + stateToActions, allStateActionPairs = translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) + strategy = parseStrategy(straFile, stateToActions) + + if plotting: plotter = VisVisPlotter(all_states, reachedStates, deadlockStates, stepwisePlotting) + if plotting: plotter.plotScenario() + passingStates = list() + + randomTestingSimulator = Simulator(allStateActionPairs, strategy, deadlockStates, reachedStates, bound) + i = 0 + print("Starting with random testing.") + numQueries = 0 + failureStates = list() + while numQueries <= maxQueries: + if i >= 500: + if plotting: plotter.plotStates(failureStates, coloring=(0.8,0.0,0.0,0.6)) + if plotting: plotter.takeScreenshot(iteration, prefix="random_testing") + if plotting: plotter.turnCamera() + if plotting: input("") + print(f"{numQueries} {len(failureStates)} ") + i = 0 + testCase, testResult, queriesForThisTestCase = randomTestingSimulator.runTest() + i += queriesForThisTestCase + numQueries += queriesForThisTestCase + stateValuation = all_states[testCase] + if testResult == Verdict.FAIL: + failureStates.append(stateValuation) + + print(f"{numQueries} {len(failureStates)} ") + +def parseArgs(): + parser = argparse.ArgumentParser() + parser.add_argument('--tra', type=str, required=True, help='Path to .tra file.') + parser.add_argument('--lab', type=str, required=True, help='Path to .lab file.') + parser.add_argument('--rew', type=str, required=True, help='Path to .rew file.') + parser.add_argument('--stra', type=str, required=True, help='Path to strategy file.') + + refinement = parser.add_mutually_exclusive_group(required=True) + refinement.add_argument('--refinement-steps', type=int, default=0, help='Amount of refinement steps per iteration, mutually exclusive with refinement-bound.') + refinement.add_argument('--refinement-bound', type=float, default=0, help='Threshold value for states to be tested, mutually exclusive with refinement-steps.') + + parser.add_argument('--bound', type=int, required=False, default=3, help='(optional) Safety Horizon Bound, defaults to 3.') + parser.add_argument('--threshold', type=float, required=False, default=0.05, help='(optional) Safety Threshold, defaults to 0.05.') + + random_testing = parser.add_mutually_exclusive_group() + random_testing.add_argument('-a', '--ablation', action='store_true', help="(optional) Run ablation testing for the importance ranking, i.e. model-based random testing.") + random_testing.add_argument('-r', '--random', type=int, default=0, help='(optional) The amount of queries allowed for random testing.') + + parser.add_argument('-p', '--plotting', action='store_true', help='(optional) Enable plotting.') + parser.add_argument('--stepwise', action='store_true', help='(optional) Remove states before plotting the next iteration.') + return parser.parse_args() + +if __name__ == '__main__': + args = parseArgs() + + traFile = args.tra + labFile = args.lab + straFile = args.stra + rewFile = args.rew + + ablationTesting = args.ablation + plotting = args.plotting + stepwisePlotting = args.stepwise + + maxQueriesForRandomTesting = args.random + horizonBound = args.bound + + refinementSteps = args.refinement_steps + refinementBound = args.refinement_bound + + if maxQueriesForRandomTesting == 0: #akward way to test for this... + main(traFile, labFile, straFile, horizonBound, refinementSteps, refinementBound, ablationTesting, plotting, stepwisePlotting) + else: + randomTesting(traFile, labFile, straFile, horizonBound, maxQueriesForRandomTesting, plotting) diff --git a/translate.py b/translate.py new file mode 100755 index 0000000..73ef8f9 --- /dev/null +++ b/translate.py @@ -0,0 +1,135 @@ +#!/usr/bin/python3 + +import sys +#import re +import json +import numpy as np +from collections import deque + +from dataclasses import dataclass + + +@dataclass(frozen=False) +class StateAction: + state_id: int + action_id: int + next_state_probabilities: dict + action_name: str + + def normalizeDistribution(self): + weight = np.sum([value for key, value in self.next_state_probabilities.items()]) + self.next_state_probabilities = { next_state_id : (probability / weight) for next_state_id, probability in self.next_state_probabilities.items() } + + +def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId): + current_state_id = -1 + current_action_id = -1 + all_state_action_pairs = list() + with open(traFile) as transitions: + next(transitions) + for line in transitions: + line = line.replace("\n","") + explode = line.split(" ") + if len(explode) < 2: continue + interval = json.loads(explode[3]) + probability = (interval[0] + interval[1])/2 + state_id = int(explode[0]) + action_id = int(explode[1]) + next_state_id = int(explode[2]) + if len(explode) >= 5: + action_name = explode[4] + else: + action_name = "" + #print(f"State : {state_id} with action {action_id} leads with {probability} to {next_state_id}.") + if state_id in [0, 1, 2]: + continue + + next_state_probabilities = {next_state_id: probability} + if current_state_id != state_id: + new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name) + all_state_action_pairs.append(new_state_action_pair) + current_state_id = state_id + current_action_id = action_id + + elif current_action_id != action_id: + new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name) + all_state_action_pairs.append(new_state_action_pair) + current_action_id = action_id + else: + all_state_action_pairs[-1].next_state_probabilities[next_state_id] = probability + + # we need to sort the deadlock and reached states to insert them while building the .tra file + deadlockStates = [(state, 0) for state in deadlockStates] + reachedStates = [(state, maxStateId) for state in reachedStates] + final_states = deadlockStates + reachedStates + final_states = deque(sorted(final_states, key=lambda tuple: tuple[0], reverse=True)) + + with open("MDP_" + traFile, "w") as new_transitions_file: + new_transitions_file.write("mdp\n") + new_transitions_file.write(f"0 0 {maxStateId} 1.0\n") + for entry in all_state_action_pairs: + entry.normalizeDistribution() + source_state = int(entry.state_id) + while final_states and int(final_states[-1][0]) < source_state: + final_state = final_states.pop() + if int(final_state[0]) == 0: continue + new_transitions_file.write(f"{final_state[0]} 0 {final_state[1]} 1.0\n") + for next_state_id, probability in entry.next_state_probabilities.items(): + new_transitions_file.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n") + new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n") + + state_to_actions = dict() + for state_action_pair in all_state_action_pairs: + if state_action_pair.state_id in state_to_actions: + state_to_actions[state_action_pair.state_id].append(state_action_pair.action_name) + else: + state_to_actions[state_action_pair.state_id] = [state_action_pair.action_name] + return state_to_actions, all_state_action_pairs + +def readLabels(labFile): + deadlockStates = list() + reachedStates = list() + with open(labFile) as states: + newLabFile = "MDP_" + labFile + newLabels = open(newLabFile, "w") + optRewards = open(newLabFile + ".optrew", "w") + safetyRewards = open(newLabFile + ".saferew", "w") + labels = ["init", "deadlock", "reached", "failed"] + next(states) + newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n") + maxStateId = -1 + for line in states: + line = line.replace(":","").replace("\n", "") + explode = line.split(" ") + newLabel = f"{explode[0]} " + if int(explode[0]) > maxStateId: maxStateId = int(explode[0]) + if int(explode[0]) == 0: + safetyRewards.write(f"{explode[0]} -100\n") + optRewards.write(f"{explode[0]} -100\n") + #if "3" in explode: + # safetyRewards.write(f"{explode[0]} -100\n") + # optRewards.write(f"{explode[0]} -100\n") + elif "2" in explode: + optRewards.write(f"{explode[0]} 100\n") + else: + optRewards.write(f"{explode[0]} -1\n") + for labelIndex in explode[1:]: + # sink states should not be deadlock states anymore: + if labelIndex == "1": + deadlockStates.append(int(explode[0])) + continue + if labelIndex == "2": + reachedStates.append(int(explode[0])) + continue + newLabel += f"{labels[int(labelIndex)]} " + newLabels.write(newLabel + "\n") + return deadlockStates, reachedStates, maxStateId + 1 + +def main(traFile, labFile): + deadlockStates, reachedStates, maxStateId = readLabels(labFile) + translateTransitions(traFile, deadlockStates, reachedStates, maxStateId) + +if __name__ == '__main__': + traFile = sys.argv[1] + labFile = sys.argv[2] + main(traFile, labFile)