You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

135 lines
5.7 KiB

#!/usr/bin/python3
import json
import os
import sys
#import re
from collections import deque
from dataclasses import dataclass

import numpy as np
@dataclass(frozen=False)
class StateAction:
    """One (state, action) pair of the MDP together with its successor distribution."""

    state_id: int  # source state id as read from the transition file
    action_id: int  # action id as read from the transition file
    next_state_probabilities: dict  # next state id -> probability (weight before normalizing)
    action_name: str  # optional action name column; "" when absent

    def normalizeDistribution(self):
        """Replace ``next_state_probabilities`` with a copy whose values sum to 1.

        An empty distribution is left unchanged.

        Raises:
            ValueError: if the weights sum to zero or less. No valid
                distribution exists then, and dividing would only produce
                NaN/inf probabilities in the output file.
        """
        if not self.next_state_probabilities:
            return
        weight = sum(self.next_state_probabilities.values())
        if weight <= 0:
            raise ValueError(
                f"cannot normalize distribution of state {self.state_id}, "
                f"action {self.action_id}: total weight is {weight}"
            )
        self.next_state_probabilities = {
            next_state_id: probability / weight
            for next_state_id, probability in self.next_state_probabilities.items()
        }
def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId):
    """Convert a transition file into an MDP transition file and index its actions.

    The first line of ``traFile`` is skipped as a header. Every other line is
    expected to read ``<state> <action> <next_state> <interval> [<name>]``.
    ``<interval>`` is a JSON list whose first two entries are averaged into a
    point probability. Lines with fewer than four fields are ignored, and so
    are transitions leaving states 0, 1 and 2. Consecutive lines with the same
    (state, action) are merged into one ``StateAction``.

    The result is written next to ``traFile`` as ``MDP_<file name>``. It
    contains:

    * the ``mdp`` header;
    * a transition from state 0 to ``maxStateId``;
    * every pair's normalized distribution;
    * a single action-0 transition from each deadlock state to state 0 and
      from each reached state to ``maxStateId``;
    * a closing self-loop on ``maxStateId``.

    Args:
        traFile: path of the input transition file.
        deadlockStates: ids of the states labelled deadlock.
        reachedStates: ids of the states labelled reached.
        maxStateId: id used for the extra sink state (``readLabels`` returns
            one more than the largest labelled state id).

    Returns:
        ``(state_to_actions, all_state_action_pairs)``. ``state_to_actions``
        maps each source state id to the action names of its pairs, in file
        order. ``all_state_action_pairs`` is the list of parsed (and now
        normalized) ``StateAction`` objects.
    """
    all_state_action_pairs = _read_state_action_pairs(traFile)
    _write_mdp_transitions(traFile, all_state_action_pairs,
                           deadlockStates, reachedStates, maxStateId)
    state_to_actions = {}
    for state_action_pair in all_state_action_pairs:
        state_to_actions.setdefault(state_action_pair.state_id, []).append(
            state_action_pair.action_name)
    return state_to_actions, all_state_action_pairs


def _read_state_action_pairs(traFile):
    """Parse ``traFile`` into StateAction objects, one per consecutive (state, action) run."""
    pairs = []
    current_key = None  # (state_id, action_id) of pairs[-1]
    with open(traFile) as transitions:
        next(transitions, None)  # skip the header line; tolerate an empty file
        for line in transitions:
            fields = line.split()
            # State, action, next state and interval are all required; the
            # old `< 2` guard let 2-3 field lines crash on fields[3].
            if len(fields) < 4:
                continue
            state_id = int(fields[0])
            if state_id in (0, 1, 2):
                continue
            action_id = int(fields[1])
            next_state_id = int(fields[2])
            interval = json.loads(fields[3])
            probability = (interval[0] + interval[1]) / 2
            action_name = fields[4] if len(fields) >= 5 else ""
            if (state_id, action_id) == current_key:
                pairs[-1].next_state_probabilities[next_state_id] = probability
            else:
                pairs.append(StateAction(state_id, action_id,
                                         {next_state_id: probability}, action_name))
                current_key = (state_id, action_id)
    return pairs


def _write_mdp_transitions(traFile, pairs, deadlockStates, reachedStates, maxStateId):
    """Write ``MDP_<name>`` next to ``traFile`` with the pairs and final-state transitions.

    Final states are merged into the output in one pass over ``pairs``.
    NOTE(review): this assumes ``pairs`` is ordered by ascending source state,
    i.e. that the input file is sorted by state. Confirm with the producer.
    """
    # Deadlock states lead to state 0 and reached states to the sink. State 0
    # is left out because its transition is written right after the header.
    pending = deque(sorted(
        [(state, 0) for state in deadlockStates if state != 0]
        + [(state, maxStateId) for state in reachedStates if state != 0],
        key=lambda final_state: final_state[0],
    ))
    # Prefix only the file name so paths with a directory keep working.
    directory, name = os.path.split(traFile)
    with open(os.path.join(directory, "MDP_" + name), "w") as out:
        out.write("mdp\n")
        out.write(f"0 0 {maxStateId} 1.0\n")
        for entry in pairs:
            entry.normalizeDistribution()
            # Emit the final states whose ids come before this source state.
            while pending and pending[0][0] < entry.state_id:
                state, target = pending.popleft()
                out.write(f"{state} 0 {target} 1.0\n")
            for next_state_id, probability in entry.next_state_probabilities.items():
                out.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n")
        # Final states above the last source state were previously dropped.
        for state, target in pending:
            out.write(f"{state} 0 {target} 1.0\n")
        out.write(f"{maxStateId} 0 {maxStateId} 1.0\n")
def readLabels(labFile):
    """Rewrite a state-label file for the MDP and write its reward files.

    The first line of ``labFile`` is skipped as a header. Every other
    non-blank line is expected to read ``<state>: <index> <index> ...``. The
    indices refer to ``init`` (0), ``deadlock`` (1), ``reached`` (2) and
    ``failed`` (3).

    Three files are written next to ``labFile``, using the base name
    ``MDP_<file name>``:

    * ``MDP_<name>``: a fixed declaration header, then each state with its
      labels. The deadlock and reached labels are removed.
    * ``MDP_<name>.optrew``: -100 for state 0, 100 for states labelled
      reached, and -1 for every other listed state.
    * ``MDP_<name>.saferew``: -100 for state 0 only.

    All output files are closed before returning.

    Args:
        labFile: path of the input label file.

    Returns:
        ``(deadlockStates, reachedStates, sinkId)``: the ids of the states
        labelled deadlock and reached, in file order, and one more than the
        largest state id seen (0 when there are no state lines).
    """
    deadlockStates = list()
    reachedStates = list()
    labels = ["init", "deadlock", "reached", "failed"]
    # Prefix only the file name so paths with a directory keep working.
    directory, name = os.path.split(labFile)
    newLabFile = os.path.join(directory, "MDP_" + name)
    maxStateId = -1
    with open(labFile) as states, \
            open(newLabFile, "w") as newLabels, \
            open(newLabFile + ".optrew", "w") as optRewards, \
            open(newLabFile + ".saferew", "w") as safetyRewards:
        next(states, None)  # skip the header line; tolerate an empty file
        newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n")
        for line in states:
            fields = line.replace(":", " ").split()
            if not fields:
                continue  # blank lines (e.g. a trailing newline) used to crash int("")
            state_text = fields[0]
            state_id = int(state_text)
            label_indices = fields[1:]
            maxStateId = max(maxStateId, state_id)
            if state_id == 0:
                safetyRewards.write(f"{state_text} -100\n")
                optRewards.write(f"{state_text} -100\n")
            # Look only at the label columns: the state-id column "2" must
            # not count as the "reached" label.
            elif "2" in label_indices:
                optRewards.write(f"{state_text} 100\n")
            else:
                optRewards.write(f"{state_text} -1\n")
            newLabel = f"{state_text} "
            for labelIndex in label_indices:
                # Sink states should not keep their deadlock/reached labels.
                if labelIndex == "1":
                    deadlockStates.append(state_id)
                elif labelIndex == "2":
                    reachedStates.append(state_id)
                else:
                    newLabel += f"{labels[int(labelIndex)]} "
            newLabels.write(newLabel + "\n")
    return deadlockStates, reachedStates, maxStateId + 1
def main(traFile, labFile):
    """Convert the label file, then the transition file, into their MDP versions.

    Labels are processed first because they provide the deadlock states, the
    reached states and the sink id that the transition conversion needs.
    """
    label_summary = readLabels(labFile)
    deadlocks, reached, sink_id = label_summary
    translateTransitions(traFile, deadlocks, reached, sink_id)
if __name__ == '__main__':
    # Expects: <transition file> <label file>. Print usage instead of a bare
    # IndexError traceback when an argument is missing.
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <transition file> <label file>")
    traFile = sys.argv[1]
    labFile = sys.argv[2]
    main(traFile, labFile)