You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

135 lines
5.7 KiB

  1. #!/usr/bin/python3
  2. import sys
  3. #import re
  4. import json
  5. import numpy as np
  6. from collections import deque
  7. from dataclasses import dataclass
  8. @dataclass(frozen=False)
  9. class StateAction:
  10. state_id: int
  11. action_id: int
  12. next_state_probabilities: dict
  13. action_name: str
  14. def normalizeDistribution(self):
  15. weight = np.sum([value for key, value in self.next_state_probabilities.items()])
  16. self.next_state_probabilities = { next_state_id : (probability / weight) for next_state_id, probability in self.next_state_probabilities.items() }
  17. def translateTransitions(traFile, deadlockStates, reachedStates, maxStateId):
  18. current_state_id = -1
  19. current_action_id = -1
  20. all_state_action_pairs = list()
  21. with open(traFile) as transitions:
  22. next(transitions)
  23. for line in transitions:
  24. line = line.replace("\n","")
  25. explode = line.split(" ")
  26. if len(explode) < 2: continue
  27. interval = json.loads(explode[3])
  28. probability = (interval[0] + interval[1])/2
  29. state_id = int(explode[0])
  30. action_id = int(explode[1])
  31. next_state_id = int(explode[2])
  32. if len(explode) >= 5:
  33. action_name = explode[4]
  34. else:
  35. action_name = ""
  36. #print(f"State : {state_id} with action {action_id} leads with {probability} to {next_state_id}.")
  37. if state_id in [0, 1, 2]:
  38. continue
  39. next_state_probabilities = {next_state_id: probability}
  40. if current_state_id != state_id:
  41. new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name)
  42. all_state_action_pairs.append(new_state_action_pair)
  43. current_state_id = state_id
  44. current_action_id = action_id
  45. elif current_action_id != action_id:
  46. new_state_action_pair = StateAction(state_id, action_id, next_state_probabilities, action_name)
  47. all_state_action_pairs.append(new_state_action_pair)
  48. current_action_id = action_id
  49. else:
  50. all_state_action_pairs[-1].next_state_probabilities[next_state_id] = probability
  51. # we need to sort the deadlock and reached states to insert them while building the .tra file
  52. deadlockStates = [(state, 0) for state in deadlockStates]
  53. reachedStates = [(state, maxStateId) for state in reachedStates]
  54. final_states = deadlockStates + reachedStates
  55. final_states = deque(sorted(final_states, key=lambda tuple: tuple[0], reverse=True))
  56. with open("MDP_" + traFile, "w") as new_transitions_file:
  57. new_transitions_file.write("mdp\n")
  58. new_transitions_file.write(f"0 0 {maxStateId} 1.0\n")
  59. for entry in all_state_action_pairs:
  60. entry.normalizeDistribution()
  61. source_state = int(entry.state_id)
  62. while final_states and int(final_states[-1][0]) < source_state:
  63. final_state = final_states.pop()
  64. if int(final_state[0]) == 0: continue
  65. new_transitions_file.write(f"{final_state[0]} 0 {final_state[1]} 1.0\n")
  66. for next_state_id, probability in entry.next_state_probabilities.items():
  67. new_transitions_file.write(f"{entry.state_id} {entry.action_id} {next_state_id} {probability}\n")
  68. new_transitions_file.write(f"{maxStateId} 0 {maxStateId} 1.0\n")
  69. state_to_actions = dict()
  70. for state_action_pair in all_state_action_pairs:
  71. if state_action_pair.state_id in state_to_actions:
  72. state_to_actions[state_action_pair.state_id].append(state_action_pair.action_name)
  73. else:
  74. state_to_actions[state_action_pair.state_id] = [state_action_pair.action_name]
  75. return state_to_actions, all_state_action_pairs
  76. def readLabels(labFile):
  77. deadlockStates = list()
  78. reachedStates = list()
  79. with open(labFile) as states:
  80. newLabFile = "MDP_" + labFile
  81. newLabels = open(newLabFile, "w")
  82. optRewards = open(newLabFile + ".optrew", "w")
  83. safetyRewards = open(newLabFile + ".saferew", "w")
  84. labels = ["init", "deadlock", "reached", "failed"]
  85. next(states)
  86. newLabels.write("#DECLARATION\ninit deadlock reached failed\n#END\n")
  87. maxStateId = -1
  88. for line in states:
  89. line = line.replace(":","").replace("\n", "")
  90. explode = line.split(" ")
  91. newLabel = f"{explode[0]} "
  92. if int(explode[0]) > maxStateId: maxStateId = int(explode[0])
  93. if int(explode[0]) == 0:
  94. safetyRewards.write(f"{explode[0]} -100\n")
  95. optRewards.write(f"{explode[0]} -100\n")
  96. #if "3" in explode:
  97. # safetyRewards.write(f"{explode[0]} -100\n")
  98. # optRewards.write(f"{explode[0]} -100\n")
  99. elif "2" in explode:
  100. optRewards.write(f"{explode[0]} 100\n")
  101. else:
  102. optRewards.write(f"{explode[0]} -1\n")
  103. for labelIndex in explode[1:]:
  104. # sink states should not be deadlock states anymore:
  105. if labelIndex == "1":
  106. deadlockStates.append(int(explode[0]))
  107. continue
  108. if labelIndex == "2":
  109. reachedStates.append(int(explode[0]))
  110. continue
  111. newLabel += f"{labels[int(labelIndex)]} "
  112. newLabels.write(newLabel + "\n")
  113. return deadlockStates, reachedStates, maxStateId + 1
  114. def main(traFile, labFile):
  115. deadlockStates, reachedStates, maxStateId = readLabels(labFile)
  116. translateTransitions(traFile, deadlockStates, reachedStates, maxStateId)
  117. if __name__ == '__main__':
  118. traFile = sys.argv[1]
  119. labFile = sys.argv[2]
  120. main(traFile, labFile)