You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
5.8 KiB

  1. import sys
  2. from random import randrange
  3. from ale_py import ALEInterface, SDL_SUPPORT, Action
  4. from colors import *
  5. from PIL import Image
  6. from matplotlib import pyplot as plt
  7. import cv2
  8. import pickle
  9. import queue
  10. from copy import deepcopy
  11. import numpy as np
  12. import readchar
  13. from sample_factory.algo.utils.tensor_dict import TensorDict
  14. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  15. import time
  def input_to_action(char):
      """Translate one input character into an emulator action or a command.

      "0", "1" and "2" map to ``Action.NOOP``, ``Action.RIGHT`` and
      ``Action.LEFT``. "3", "4" and "5" map to the command strings "reset",
      "set_x" and "set_vel". "w", "a", "s" and "d" are returned unchanged.
      Every other input yields ``None``.
      """
      # The Action attribute is only looked up once a key actually matches.
      for key, action_name in (("0", "NOOP"), ("1", "RIGHT"), ("2", "LEFT")):
          if char == key:
              return getattr(Action, action_name)

      for key, command in (("3", "reset"), ("4", "set_x"), ("5", "set_vel")):
          if char == key:
              return command

      if char in ("w", "a", "s", "d"):
          return char

      return None
  # Ski position index -> (action, repetitions). run_single_test applies the
  # action that many times before starting a test.
  ski_position_counter = {
      1: (Action.LEFT, 40),
      2: (Action.LEFT, 35),
      3: (Action.LEFT, 30),
      4: (Action.LEFT, 10),
      5: (Action.NOOP, 1),
      6: (Action.RIGHT, 10),
      7: (Action.RIGHT, 30),
      8: (Action.RIGHT, 40),
  }
  def run_single_test(ale, nn_wrapper, x, y, ski_position, duration=200):
      """Restore a stored game state and let the network play from it.

      Args:
          ale: emulator interface; used through ``setRAM``, ``act`` and
              ``getScreenGrayscale``.
          nn_wrapper: object whose ``query`` method receives the stacked
              frames and returns the action to take.
          x: value written to RAM cell 25 (the cell this script prints as
              ``player_x``).
          y: key into the module-level ``ramDICT``; the RAM image stored under
              it is written into the emulator cell by cell.
          ski_position: key into ``ski_position_counter``.
          duration: number of frames to run after the state is restored.

      Raises:
          KeyError: if ``y`` or ``ski_position`` is not a known key.
      """
      # Function-scope import keeps this block self-contained.
      from collections import deque

      print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}")

      for address, value in enumerate(ramDICT[y]):
          ale.setRAM(address, value)

      # Turn the skis into the requested position. RAM cell 14 is the cell the
      # interactive "Velocity:" prompt writes to; it is zeroed and x is
      # re-pinned after every step.
      # NOTE(review): nesting reconstructed from a whitespace-stripped source --
      # confirm the two setRAM calls belong inside this loop.
      turn_action, repetitions = ski_position_counter[ski_position]
      for _ in range(repetitions):
          ale.act(turn_action)
          ale.setRAM(14, 0)
          ale.setRAM(25, x)
      ale.setRAM(14, 180)

      # Only the four most recent frames are ever fed to the network, so keep
      # a bounded window instead of every frame of the run.
      frames = deque(maxlen=4)
      for _ in range(duration):
          frame = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
          frames.append(frame)
          if len(frames) == frames.maxlen:
              stack_tensor = TensorDict({"obs": np.array(list(frames))})
              action = nn_wrapper.query(stack_tensor)
              ale.act(input_to_action(str(action)))
          else:
              # Not enough frames for a full stack yet.
              ale.act(Action.NOOP)
          # Short pause per frame (presumably so the rendered run can be watched).
          time.sleep(0.005)
  # Create the emulator; switch on sound and the display window when SDL is available.
  ale = ALEInterface()
  if SDL_SUPPORT:
      for option in ("sound", "display_screen"):
          ale.setBool(option, True)

  # Load the ROM file
  rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
  ale.loadROM(rom_file)

  # Load the stored RAM images; they are looked up by y position below.
  # NOTE: pickle.load executes arbitrary code from the file -- only use a
  # trusted, locally produced pickle here.
  with open('all_positions_v2.pickle', 'rb') as handle:
      ramDICT = pickle.load(handle)
  #ramDICT = dict()
  #for i,r in enumerate(ramDICT[235]):
  #    ale.setRAM(i,r)

  # Starting position used by the interactive session further down.
  y_ram_setting = 60
  x = 70

  nn_wrapper = SampleFactoryNNQueryWrapper()

  #run_single_test(ale, nn_wrapper, 70,61,5)
  #input("")

  # Scripted runs as (x, y, ski_position, extra keyword arguments).
  # The loop variables are deliberately not called x / y so the starting
  # position above is left untouched.
  scripted_runs = (
      (30, 61, 5, {"duration": 1000}),
      (114, 170, 7, {}),
      (124, 170, 5, {}),
      (134, 170, 2, {}),
      (120, 185, 1, {}),
      (134, 170, 8, {}),
      (85, 195, 8, {}),
  )
  for start_x, start_y, start_ski_position, extra in scripted_runs:
      run_single_test(ale, nn_wrapper, start_x, start_y, start_ski_position, **extra)
  # Interactive session: step through the game with the keyboard.
  #   0 / 1 / 2  -> NOOP / RIGHT / LEFT
  #   w / s      -> restore the stored RAM image one y position up / down
  #   a / d      -> move x by one
  #   3 / 4 / 5  -> prompt for a RAM position / an x value / a velocity
  # NOTE(review): block nesting was reconstructed from a whitespace-stripped
  # source -- confirm against the original file.


  def _restore_ram_image(y_position):
      """Write the RAM image stored under y_position into the emulator."""
      for address, value in enumerate(ramDICT[y_position]):
          ale.setRAM(address, value)


  velocity_set = False
  # Seed the previous-RAM snapshot so the very first diff has something to
  # compare against (it was previously read before ever being assigned).
  oldram = deepcopy(ale.getRAM())
  for episode in range(10):
      total_reward = 0
      j = 0
      while not ale.game_over():
          # Until a velocity is entered explicitly, RAM cell 14 (the cell the
          # "Velocity:" prompt writes to) is zeroed on every step.
          if not velocity_set:
              ale.setRAM(14, 0)
          j += 1
          a = input_to_action(repr(readchar.readchar())[1])
          #a = Action.NOOP
          if a is None:
              # Unmapped key: there is nothing valid to hand to ale.act.
              continue

          if a == "w":
              y_ram_setting -= 1
              if y_ram_setting <= 61:
                  y_ram_setting = 61
              _restore_ram_image(y_ram_setting)
              ale.setRAM(25, x)
              ale.act(Action.NOOP)
          elif a == "s":
              y_ram_setting += 1
              # Limits kept exactly as in the original: reaching 1950 snaps back to 1945.
              if y_ram_setting >= 1950:
                  y_ram_setting = 1945
              _restore_ram_image(y_ram_setting)
              ale.setRAM(25, x)
              ale.act(Action.NOOP)
          elif a == "a":
              x -= 1
              if x <= 0:
                  x = 0
              ale.setRAM(25, x)
              ale.act(Action.NOOP)
          elif a == "d":
              x += 1
              if x >= 144:
                  x = 144
              ale.setRAM(25, x)
              ale.act(Action.NOOP)
          elif a == "reset":
              ram_pos = input("Ram Position:")
              _restore_ram_image(int(ram_pos))
              ale.act(Action.NOOP)
          elif a == "set_x":
              x = int(input("X:"))
              ale.setRAM(25, x)
              ale.act(Action.NOOP)
          elif a == "set_vel":
              vel = input("Velocity:")
              ale.setRAM(14, int(vel))
              ale.act(Action.NOOP)
              velocity_set = True
          else:
              # Apply an action and add the resulting reward to the episode
              # score (previously the score stayed at 0).
              reward = ale.act(a)
              total_reward += reward

          ram = ale.getRAM()
          #if j % 2 == 0:
          #    y_pixel = int(j*1/2) + 55
          #    ramDICT[y_pixel] = ram
          #    print(f"saving to {y_pixel:04}")
          #    if y_pixel == 126 or y_pixel == 235:
          #        input("")

          # Per-cell change since the previous step; values are converted to
          # plain ints first (presumably to avoid unsigned wrap-around).
          # Only the commented-out dump below consumes the result.
          int_old_ram = list(map(int, oldram))
          int_ram = list(map(int, ram))
          difference = [new - old for old, new in zip(int_old_ram, int_ram)]
          oldram = deepcopy(ram)

          #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
          print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
          #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")

          #for i, r in enumerate(ram):
          #    print('{:03}:{:02x} '.format(i,r), end="")
          #    if i % 16 == 15: print("")
          #print("")
          #for i, r in enumerate(difference):
          #    string = '{:02}:{:03} '.format(i%100,r)
          #    if r != 0:
          #        print(color(string, fg='red'), end="")
          #    else:
          #        print(string, end="")
          #    if i % 16 == 15: print("")

      print("Episode %d ended with score: %d" % (episode, total_reward))
      input("")
      # Write the RAM images back to disk before starting the next episode.
      with open('all_positions_v2.pickle', 'wb') as handle:
          pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
      ale.reset_game()