|
|
#!/usr/bin/python3
import json
import os
import sys
# import re
from collections import deque
from dataclasses import dataclass

import numpy as np
@dataclass(frozen=False)
class StateAction:
    """One action of one state together with its successor distribution.

    Instances are mutable: ``next_state_probabilities`` is filled in
    incrementally while parsing and later rescaled in place by
    ``normalizeDistribution``.
    """

    # Id of the source state this action belongs to.
    state_id: int
    # Id of the action within the source state.
    action_id: int
    # Maps successor state id -> (possibly unnormalized) probability weight.
    next_state_probabilities: dict
    # Name attached to the action; empty string when none was given.
    action_name: str

    def normalizeDistribution(self):
        """Rescale ``next_state_probabilities`` in place so the values sum to 1.

        An empty distribution is left untouched.

        Raises:
            ValueError: if the weights are present but sum to zero, since
                dividing by the total would otherwise yield NaN/inf entries.
        """
        if not self.next_state_probabilities:
            return
        total = sum(self.next_state_probabilities.values())
        if total == 0:
            raise ValueError(
                f"cannot normalize distribution of state {self.state_id}, "
                f"action {self.action_id}: weights sum to zero"
            )
        self.next_state_probabilities = {
            next_state_id: probability / total
            for next_state_id, probability in self.next_state_probabilities.items()
        }
def _readStateActionPairs(traFile):
    """Parse ``traFile`` into one StateAction per consecutive (state, action) run."""
    pairs = []
    current_key = None
    with open(traFile) as transitions:
        # The first line is not a transition; tolerate a completely empty file.
        next(transitions, None)
        for line in transitions:
            fields = line.split()
            # state, action, next state and interval are all required below.
            if len(fields) < 4:
                continue
            state_id = int(fields[0])
            # Transitions leaving states 0, 1 and 2 are dropped.
            # NOTE(review): presumably these are reserved states - confirm.
            if state_id in (0, 1, 2):
                continue
            action_id = int(fields[1])
            next_state_id = int(fields[2])
            # The fourth field is a JSON array [lower, upper]; use its midpoint.
            interval = json.loads(fields[3])
            probability = (interval[0] + interval[1]) / 2
            action_name = fields[4] if len(fields) >= 5 else ""

            key = (state_id, action_id)
            if key != current_key:
                pairs.append(
                    StateAction(
                        state_id,
                        action_id,
                        {next_state_id: probability},
                        action_name,
                    )
                )
                current_key = key
            else:
                # Same state and action as the previous line: extend its distribution.
                pairs[-1].next_state_probabilities[next_state_id] = probability
    return pairs


def _writeFinalStates(out, final_states, below=None):
    """Pop and write pending final states with id < ``below`` (all if None)."""
    while final_states and (below is None or int(final_states[-1][0]) < below):
        state, target = final_states.pop()
        # State 0 already has its own transition at the top of the file.
        if int(state) == 0:
            continue
        out.write(f"{state} 0 {target} 1.0\n")


def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId):
    """Rewrite an interval transition file as an ``MDP_``-prefixed transition file.

    Each interval is replaced by its midpoint and every (state, action)
    distribution is normalized. State 0 and every reached state are
    redirected to ``maxStateId``, every deadlock state is redirected to
    state 0, and ``maxStateId`` gets a self-loop. The output is written next
    to the input, with ``MDP_`` prepended to the file name.

    Args:
        traFile: path of the transition file to read.
        deadlockStates: ids of states that get a single transition to state 0.
        reachedStates: ids of states that get a single transition to
            ``maxStateId``.
        maxStateId: id of the additional sink state.

    Returns:
        A tuple ``(state_to_actions, all_state_action_pairs)``: a dict mapping
        each state id to the list of its action names, and the list of
        normalized ``StateAction`` objects in file order.

    Raises:
        ValueError, IndexError, json.JSONDecodeError: on malformed fields.
    """
    all_state_action_pairs = _readStateActionPairs(traFile)

    # Final states are interleaved with the parsed transitions by id. Keep them
    # in descending order so the smallest pending id sits at the right end.
    final_states = [(state, 0) for state in deadlockStates]
    final_states += [(state, maxStateId) for state in reachedStates]
    final_states = deque(
        sorted(final_states, key=lambda final: final[0], reverse=True)
    )

    # Prefix the file name only, so inputs given with a directory still work.
    out_path = os.path.join(
        os.path.dirname(traFile), "MDP_" + os.path.basename(traFile)
    )
    with open(out_path, "w") as new_transitions_file:
        new_transitions_file.write("mdp\n")
        new_transitions_file.write(f"0 0 {maxStateId} 1.0\n")
        for entry in all_state_action_pairs:
            entry.normalizeDistribution()
            _writeFinalStates(
                new_transitions_file, final_states, below=int(entry.state_id)
            )
            for next_state_id, probability in entry.next_state_probabilities.items():
                new_transitions_file.write(
                    f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n"
                )
        # Final states at or beyond the last source state would otherwise be
        # left without any outgoing transition.
        _writeFinalStates(new_transitions_file, final_states)
        new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n")

    state_to_actions = {}
    for state_action_pair in all_state_action_pairs:
        state_to_actions.setdefault(state_action_pair.state_id, []).append(
            state_action_pair.action_name
        )
    return state_to_actions, all_state_action_pairs
def readLabels(labFile):
    """Convert a label file and derive reward files and final-state lists.

    The first line of ``labFile`` is skipped. Every following non-blank line
    is read as a state id followed by numeric label indices (colons are
    ignored). Three files are written next to the input, named after it with
    an ``MDP_`` prefix:

    * ``MDP_<name>``: label declarations plus one line per state listing its
      remaining labels. Label indices 1 and 2 are not written.
    * ``MDP_<name>.optrew``: -100 for state 0, 100 for states carrying label
      index 2, -1 for every other state.
    * ``MDP_<name>.saferew``: -100 for state 0 only.

    Args:
        labFile: path of the label file to read.

    Returns:
        A tuple ``(deadlockStates, reachedStates, sinkStateId)``: the ids of
        states carrying label index 1, the ids of states carrying label
        index 2, and the largest state id seen plus one.

    Raises:
        ValueError: if a state id or label index is not an integer.
        IndexError: if a label index has no entry in the label table.
    """
    labels = ["init", "deadlock", "reached", "failed"]
    deadlockStates = []
    reachedStates = []
    maxStateId = -1

    # Prefix the file name only, so inputs given with a directory still work.
    newLabFile = os.path.join(
        os.path.dirname(labFile), "MDP_" + os.path.basename(labFile)
    )
    with open(labFile) as states, \
            open(newLabFile, "w") as newLabels, \
            open(newLabFile + ".optrew", "w") as optRewards, \
            open(newLabFile + ".saferew", "w") as safetyRewards:
        # The first line is not a state line; tolerate a completely empty file.
        next(states, None)
        newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n")
        for line in states:
            tokens = line.replace(":", "").split()
            if not tokens:
                continue
            state = tokens[0]
            label_indices = tokens[1:]
            state_id = int(state)
            maxStateId = max(maxStateId, state_id)

            if state_id == 0:
                safetyRewards.write(f"{state} -100\n")
                optRewards.write(f"{state} -100\n")
            elif "2" in label_indices:
                # Only the label tokens are checked, so the id of state 2
                # itself cannot be mistaken for label index 2.
                optRewards.write(f"{state} 100\n")
            else:
                optRewards.write(f"{state} -1\n")

            kept_labels = []
            for labelIndex in label_indices:
                # sink states should not be deadlock states anymore:
                if labelIndex == "1":
                    deadlockStates.append(state_id)
                elif labelIndex == "2":
                    reachedStates.append(state_id)
                else:
                    kept_labels.append(labels[int(labelIndex)])
            # Every token, including the last, is followed by a single space.
            newLabels.write(
                "".join(f"{token} " for token in [state, *kept_labels]) + "\n"
            )
    return deadlockStates, reachedStates, maxStateId + 1
def main(traFile, labFile):
    """Convert the label file first, then the transition file.

    The three values returned by ``readLabels`` (deadlock states, reached
    states and the sink state id) are passed straight on to
    ``translateTransitions``.
    """
    label_info = readLabels(labFile)
    translateTransitions(traFile, *label_info)
if __name__ == '__main__':
    # Both input paths are required; report a usage message instead of
    # failing with a bare IndexError when one is missing.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <transitions-file> <labels-file>")
    traFile = sys.argv[1]
    labFile = sys.argv[2]
    main(traFile, labFile)
|