You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
135 lines
5.7 KiB
135 lines
5.7 KiB
#!/usr/bin/python3
|
|
|
|
import json
import os
import sys
#import re
from collections import deque
from dataclasses import dataclass

import numpy as np
|
|
|
|
|
|
@dataclass(frozen=False)
class StateAction:
    """One action of one state together with its successor distribution."""

    # Id of the state the action is taken in.
    state_id: int
    # Id of the action within that state.
    action_id: int
    # Maps successor state id -> weight; weights sum to one only after
    # normalizeDistribution() has been called.
    next_state_probabilities: dict
    # Name attached to the action.
    action_name: str

    def normalizeDistribution(self):
        """Rescale the successor weights in place so that they sum to one.

        The dictionary is replaced by a new one; keys and their order are kept.
        NOTE(review): a total weight of zero is not handled specially and is
        divided by as-is.
        """
        total = np.sum(list(self.next_state_probabilities.values()))
        rescaled = {}
        for successor, weight in self.next_state_probabilities.items():
            rescaled[successor] = weight / total
        self.next_state_probabilities = rescaled
|
|
|
|
|
|
def _readStateActionPairs(traFile):
    """Parse the interval transition file into a list of StateAction objects.

    The first line of the file is skipped. Every following line is split on
    whitespace into `state action successor interval [action_name]`, where
    `interval` is a JSON list whose first two entries are averaged to obtain
    the transition weight. Consecutive lines sharing state and action are
    merged into one StateAction. Lines whose source state is 0, 1 or 2 are
    ignored.
    """
    pairs = []
    current_state_id = -1
    current_action_id = -1
    with open(traFile) as transitions:
        # Skip the header line; tolerate a completely empty file.
        next(transitions, None)
        for line in transitions:
            fields = line.split()
            # Four fields are read below, so anything shorter (blank or
            # truncated lines) cannot be a transition.
            if len(fields) < 4:
                continue
            state_id = int(fields[0])
            if state_id in (0, 1, 2):
                continue
            action_id = int(fields[1])
            next_state_id = int(fields[2])
            interval = json.loads(fields[3])
            # Collapse the interval to its midpoint.
            probability = (interval[0] + interval[1]) / 2
            action_name = fields[4] if len(fields) >= 5 else ""
            #print(f"State : {state_id} with action {action_id} leads with {probability} to {next_state_id}.")

            if state_id != current_state_id or action_id != current_action_id:
                # A new state or a new action of the same state starts a new pair.
                pairs.append(StateAction(state_id, action_id, {next_state_id: probability}, action_name))
                current_state_id = state_id
                current_action_id = action_id
            else:
                pairs[-1].next_state_probabilities[next_state_id] = probability
    return pairs


def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId):
    """Translate an interval transition file into an MDP transition file.

    Reads `traFile`, normalizes every action's successor distribution and
    writes the result to a file named like `traFile` with its base name
    prefixed by "MDP_" (in the same directory as `traFile`).

    Besides the parsed transitions the output contains:
      * a header line "mdp",
      * "0 0 maxStateId 1.0" for state 0,
      * one "s 0 0 1.0" line per deadlock state s (except state 0),
      * one "s 0 maxStateId 1.0" line per reached state s,
      * a self loop "maxStateId 0 maxStateId 1.0".

    Args:
        traFile: path of the input transition file.
        deadlockStates: iterable of state ids that are redirected to state 0.
        reachedStates: iterable of state ids that are redirected to maxStateId.
        maxStateId: id of the additional final state.

    Returns:
        A tuple `(state_to_actions, all_state_action_pairs)`: a dict mapping
        each parsed state id to the list of its action names, and the list of
        (normalized) StateAction objects in file order.
    """
    all_state_action_pairs = _readStateActionPairs(traFile)

    # we need to sort the deadlock and reached states to insert them while building the .tra file
    final_states = [(state, 0) for state in deadlockStates]
    final_states += [(state, maxStateId) for state in reachedStates]
    # Descending order, so the smallest pending state sits at the right end.
    final_states = deque(sorted(final_states, key=lambda pair: pair[0], reverse=True))

    # Prefix only the file name so that paths containing a directory work.
    traDir, traName = os.path.split(traFile)
    with open(os.path.join(traDir, "MDP_" + traName), "w") as new_transitions_file:

        def flush_final_states(limit=None):
            # Emit pending final states below `limit` (all of them if None).
            while final_states and (limit is None or int(final_states[-1][0]) < limit):
                state, target = final_states.pop()
                if int(state) == 0:
                    continue  # state 0 already has its transition in the header
                new_transitions_file.write(f"{state} 0 {target} 1.0\n")

        new_transitions_file.write("mdp\n")
        new_transitions_file.write(f"0 0 {maxStateId} 1.0\n")
        for entry in all_state_action_pairs:
            entry.normalizeDistribution()
            flush_final_states(int(entry.state_id))
            for next_state_id, probability in entry.next_state_probabilities.items():
                new_transitions_file.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n")
        # Final states with ids above the last parsed source state would
        # otherwise never be written.
        flush_final_states()
        new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n")

    state_to_actions = dict()
    for state_action_pair in all_state_action_pairs:
        state_to_actions.setdefault(state_action_pair.state_id, []).append(state_action_pair.action_name)
    return state_to_actions, all_state_action_pairs
|
|
|
|
def readLabels(labFile):
    """Translate a label file and derive reward files from it.

    The first line of `labFile` is skipped. Every following non-blank line is
    read as `state: label_index ...` where the label indices refer to
    ["init", "deadlock", "reached", "failed"].

    Three files are written next to `labFile`, named after it with the base
    name prefixed by "MDP_":
      * the new label file, which lists for every state its labels except
        "deadlock" (index 1) and "reached" (index 2),
      * `<new label file>.optrew` with -100 for state 0, 100 for states
        labelled "reached" and -1 for all others,
      * `<new label file>.saferew` with -100 for state 0.

    Args:
        labFile: path of the input label file.

    Returns:
        A tuple `(deadlockStates, reachedStates, maxStateId)` holding the ids
        of the states labelled 1, the ids of the states labelled 2, and the
        largest state id seen plus one.
    """
    labels = ["init", "deadlock", "reached", "failed"]
    deadlockStates = list()
    reachedStates = list()
    maxStateId = -1

    # Prefix only the file name so that paths containing a directory work.
    labDir, labName = os.path.split(labFile)
    newLabFile = os.path.join(labDir, "MDP_" + labName)

    # All output files are managed by the with-statement so they are flushed
    # and closed even if parsing fails half way.
    with open(labFile) as states, \
            open(newLabFile, "w") as newLabels, \
            open(newLabFile + ".optrew", "w") as optRewards, \
            open(newLabFile + ".saferew", "w") as safetyRewards:
        # Skip the header line; tolerate a completely empty file.
        next(states, None)
        newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n")
        for line in states:
            explode = line.replace(":", "").split()
            if not explode:
                continue  # blank line
            stateId = int(explode[0])
            labelIndices = explode[1:]
            if stateId > maxStateId:
                maxStateId = stateId

            if stateId == 0:
                safetyRewards.write(f"{explode[0]} -100\n")
                optRewards.write(f"{explode[0]} -100\n")
            # Only look at the labels: the state id token itself must not be
            # mistaken for label index 2.
            elif "2" in labelIndices:
                optRewards.write(f"{explode[0]} 100\n")
            else:
                optRewards.write(f"{explode[0]} -1\n")

            newLabel = f"{explode[0]} "
            for labelIndex in labelIndices:
                # sink states should not be deadlock states anymore:
                if labelIndex == "1":
                    deadlockStates.append(stateId)
                    continue
                if labelIndex == "2":
                    reachedStates.append(stateId)
                    continue
                newLabel += f"{labels[int(labelIndex)]} "
            newLabels.write(newLabel + "\n")
    return deadlockStates, reachedStates, maxStateId + 1
|
|
|
|
def main(traFile, labFile):
    """Run the conversion: read the labels first, then translate the transitions.

    The three values returned by readLabels (deadlock states, reached states
    and the maximal state id) are forwarded to translateTransitions.
    """
    label_info = readLabels(labFile)
    translateTransitions(traFile, *label_info)
|
|
|
|
if __name__ == '__main__':
    # Two positional arguments are required: the transition file followed by
    # the label file. Exit with a usage hint instead of an IndexError traceback.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <transition file> <label file>")
    traFile = sys.argv[1]
    labFile = sys.argv[2]
    main(traFile, labFile)
|