You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

251 lines
9.1 KiB

  1. import sys
  2. import operator
  3. from os import listdir, system
  4. from random import randrange
  5. from ale_py import ALEInterface, SDL_SUPPORT, Action
  6. from colors import *
  7. from PIL import Image
  8. from matplotlib import pyplot as plt
  9. import cv2
  10. import pickle
  11. import queue
  12. from dataclasses import dataclass, field
  13. from enum import Enum
  14. from copy import deepcopy
  15. import numpy as np
  16. import readchar
  17. from sample_factory.algo.utils.tensor_dict import TensorDict
  18. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  19. import time
  20. tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
  21. rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
  22. class Verdict(Enum):
  23. INCONCLUSIVE = 1
  24. GOOD = 2
  25. BAD = 3
  26. verdict_to_color_map = {Verdict.BAD: "200,0,0", Verdict.INCONCLUSIVE: "40,40,200", Verdict.GOOD: "00,200,100"}
  27. def convert(tuples):
  28. return dict(tuples)
  29. @dataclass(frozen=True)
  30. class State:
  31. x: int
  32. y: int
  33. ski_position: int
  34. def default_value():
  35. return {'action' : None, 'choiceValue' : None}
  36. @dataclass(frozen=True)
  37. class StateValue:
  38. ranking: float
  39. choices: dict = field(default_factory=default_value)
  40. def exec(command,verbose=True):
  41. if verbose: print(f"Executing {command}")
  42. system(f"echo {command} >> list_of_exec")
  43. return system(command)
  44. def model_to_actual(ski_position):
  45. if ski_position == 1:
  46. return 1
  47. elif ski_position in [2,3]:
  48. return 2
  49. elif ski_position in [4,5]:
  50. return 3
  51. elif ski_position in [6,7]:
  52. return 4
  53. elif ski_position in [8,9]:
  54. return 5
  55. elif ski_position in [10,11]:
  56. return 6
  57. elif ski_position in [12,13]:
  58. return 7
  59. elif ski_position == 14:
  60. return 8
  61. def input_to_action(char):
  62. if char == "0":
  63. return Action.NOOP
  64. if char == "1":
  65. return Action.RIGHT
  66. if char == "2":
  67. return Action.LEFT
  68. if char == "3":
  69. return "reset"
  70. if char == "4":
  71. return "set_x"
  72. if char == "5":
  73. return "set_vel"
  74. if char in ["w", "a", "s", "d"]:
  75. return char
  76. def drawImportantStates(important_states):
  77. draw_commands = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list(), 9:list(), 10:list(), 11:list(), 12:list(), 13:list(), 14:list()}
  78. for state in important_states:
  79. x = state[0].x
  80. y = state[0].y
  81. markerSize = 2
  82. ski_position = state[0].ski_position
  83. draw_commands[ski_position].append(f"-fill 'rgba(255,204,0,{state[1].ranking})' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '")
  84. for i in range(1,15):
  85. command = f"convert images/1_full_scaled_down.png {' '.join(draw_commands[i])} first_try_{i:02}.png"
  86. exec(command)
  87. ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
  88. def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=200):
  89. #print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}", end="")
  90. for i, r in enumerate(ramDICT[y]):
  91. ale.setRAM(i,r)
  92. ski_position_setting = ski_position_counter[ski_position]
  93. for i in range(0,ski_position_setting[1]):
  94. ale.act(ski_position_setting[0])
  95. ale.setRAM(14,0)
  96. ale.setRAM(25,x)
  97. ale.setRAM(14,180)
  98. all_obs = list()
  99. speed_list = list()
  100. first_action_set = False
  101. first_action = 0
  102. for i in range(0,duration):
  103. resized_obs = cv2.resize(ale.getScreenGrayscale() , (84,84), interpolation=cv2.INTER_AREA)
  104. all_obs.append(resized_obs)
  105. if len(all_obs) >= 4:
  106. stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
  107. action = nn_wrapper.query(stack_tensor)
  108. if not first_action_set:
  109. first_action_set = True
  110. first_action = input_to_action(str(action))
  111. ale.act(input_to_action(str(action)))
  112. else:
  113. ale.act(Action.NOOP)
  114. speed_list.append(ale.getRAM()[14])
  115. if len(speed_list) > 15 and sum(speed_list[-6:-1]) == 0:
  116. return (Verdict.BAD, first_action)
  117. #time.sleep(0.005)
  118. return (Verdict.INCONCLUSIVE, first_action)
  119. def optimalAction(choices):
  120. return max(choices.items(), key=operator.itemgetter(1))[0]
  121. def computeStateRanking(mdp_file):
  122. command = f"{tempest_binary} --prism {mdp_file} --buildchoicelab --buildstateval --prop 'Rmax=? [C <= 1000]'"
  123. exec(command)
  124. def fillStateRanking(file_name, match=""):
  125. state_ranking = dict()
  126. try:
  127. with open(file_name, "r") as f:
  128. file_content = f.readlines()
  129. for line in file_content:
  130. if not "move=0" in line: continue
  131. stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
  132. #print("stateMapping", stateMapping)
  133. choices = convert(re.findall(r"[a-zA-Z_]*(left|right|noop)[a-zA-Z_]*:(-?\d+\.?\d*)", line))
  134. choices = {key:float(value) for (key,value) in choices.items()}
  135. #print("choices", choices)
  136. ranking_value = float(re.search(r"Value:([+-]?(\d*\.\d+)|\d+)", line)[0].replace("Value:",""))
  137. #print("ranking_value", ranking_value)
  138. state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
  139. value = StateValue(ranking_value, choices)
  140. state_ranking[state] = value
  141. return state_ranking
  142. except EnvironmentError:
  143. print("Ranking file not available. Exiting.")
  144. sys.exit(1)
  145. fixed_left_states = list()
  146. fixed_right_states = list()
  147. fixed_noop_states = list()
  148. def populate_fixed_actions(state, action):
  149. if action == Action.LEFT:
  150. fixed_left_states.append(state)
  151. if action == Action.RIGHT:
  152. fixed_right_states.append(state)
  153. if action == Action.NOOP:
  154. fixed_noop_states.append(state)
  155. def update_prism_file(old_prism_file, new_prism_file):
  156. fixed_left_formula = "formula Fixed_Left = false "
  157. fixed_right_formula = "formula Fixed_Right = false "
  158. fixed_noop_formula = "formula Fixed_Noop = false "
  159. for state in fixed_left_states:
  160. fixed_left_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
  161. for state in fixed_right_states:
  162. fixed_right_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
  163. for state in fixed_noop_states:
  164. fixed_noop_formula += f" | (x={state.x}&y={state.y}&ski_position={state.ski_position}) "
  165. fixed_left_formula += ";\n"
  166. fixed_right_formula += ";\n"
  167. fixed_noop_formula += ";\n"
  168. with open(f'{old_prism_file}', 'r') as file :
  169. filedata = file.read()
  170. if len(fixed_left_states) > 0: filedata = re.sub(r"^formula Fixed_Left =.*$", fixed_left_formula, filedata, flags=re.MULTILINE)
  171. if len(fixed_right_states) > 0: filedata = re.sub(r"^formula Fixed_Right =.*$", fixed_right_formula, filedata, flags=re.MULTILINE)
  172. if len(fixed_noop_states) > 0: filedata = re.sub(r"^formula Fixed_Noop =.*$", fixed_noop_formula, filedata, flags=re.MULTILINE)
  173. with open(f'{new_prism_file}', 'w') as file:
  174. file.write(filedata)
  175. ale = ALEInterface()
  176. #if SDL_SUPPORT:
  177. # ale.setBool("sound", True)
  178. # ale.setBool("display_screen", True)
  179. # Load the ROM file
  180. ale.loadROM(rom_file)
  181. with open('all_positions_v2.pickle', 'rb') as handle:
  182. ramDICT = pickle.load(handle)
  183. y_ram_setting = 60
  184. x = 70
  185. nn_wrapper = SampleFactoryNNQueryWrapper()
  186. iteration = 0
  187. id = int(time.time())
  188. init_mdp = "velocity"
  189. exec(f"mkdir -p images/testing_{id}")
  190. exec(f"cp 1_full_scaled_down.png images/testing_{id}/testing_0000.png")
  191. exec(f"cp {init_mdp}.prism {init_mdp}_000.prism")
  192. markerSize = 1
  193. markerList = {1: list(), 2:list(), 3:list(), 4:list(), 5:list(), 6:list(), 7:list(), 8:list()}
  194. while True:
  195. computeStateRanking(f"{init_mdp}_{iteration:03}.prism")
  196. ranking = fillStateRanking("action_ranking")
  197. sorted_ranking = sorted(ranking.items(), key=lambda x: x[1].ranking)
  198. for important_state in sorted_ranking[-100:-1]:
  199. optimal_choice = optimalAction(important_state[1].choices)
  200. #print(important_state[1].choices, f"\t\tOptimal: {optimal_choice}")
  201. x = important_state[0].x
  202. y = important_state[0].y
  203. ski_pos = model_to_actual(important_state[0].ski_position)
  204. result = run_single_test(ale,nn_wrapper,x,y,ski_pos, duration=50)
  205. #print(f".... {result}")
  206. marker = f"-fill 'rgba({verdict_to_color_map[result[0]],0.7})' -draw 'rectangle {x-markerSize},{y-markerSize} {x+markerSize},{y+markerSize} '"
  207. markerList[ski_pos].append(marker)
  208. populate_fixed_actions(important_state[0], result[1])
  209. for pos, marker in markerList.items():
  210. command = f"convert images/testing_{id}/testing_0000.png {' '.join(marker)} images/testing_{id}/testing_{iteration+1:03}_{pos:02}.png"
  211. exec(command, verbose=False)
  212. exec(f"montage images/testing_{id}/testing_{iteration+1:03}_*png -geometry +0+0 -tile x1 images/testing_{id}/{iteration+1:03}.png", verbose=False)
  213. iteration += 1
  214. update_prism_file(f"{init_mdp}_{iteration-1:03}.prism", f"{init_mdp}_{iteration:03}.prism")