You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

156 lines
4.5 KiB

  1. import sys
  2. import operator
  3. from os import listdir, system
  4. from random import randrange
  5. from ale_py import ALEInterface, SDL_SUPPORT, Action
  6. from colors import *
  7. from PIL import Image
  8. from matplotlib import pyplot as plt
  9. import cv2
  10. import pickle
  11. import queue
  12. from dataclasses import dataclass, field
  13. from enum import Enum
  14. from copy import deepcopy
  15. import numpy as np
  16. import readchar
  17. from sample_factory.algo.utils.tensor_dict import TensorDict
  18. from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
  19. import time
  20. tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
  21. rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
  22. def input_to_action(char):
  23. if char == "0":
  24. return Action.NOOP
  25. if char == "1":
  26. return Action.RIGHT
  27. if char == "2":
  28. return Action.LEFT
  29. if char == "3":
  30. return "reset"
  31. if char == "4":
  32. return "set_x"
  33. if char == "5":
  34. return "set_vel"
  35. if char in ["w", "a", "s", "d"]:
  36. return char
  37. ale = ALEInterface()
  38. if SDL_SUPPORT:
  39. ale.setBool("sound", True)
  40. ale.setBool("display_screen", True)
  41. # Load the ROM file
  42. ale.loadROM(rom_file)
  43. with open('all_positions_v2.pickle', 'rb') as handle:
  44. ramDICT = pickle.load(handle)
  45. y_ram_setting = 60
  46. x = 70
  47. oldram = deepcopy(ale.getRAM())
  48. velocity_set = False
  49. for episode in range(10):
  50. total_reward = 0
  51. j = 0
  52. while not ale.game_over():
  53. if not velocity_set: ale.setRAM(14,0)
  54. j += 1
  55. a = input_to_action(repr(readchar.readchar())[1])
  56. #a = Action.NOOP
  57. if a == "w":
  58. y_ram_setting -= 1
  59. if y_ram_setting <= 61:
  60. y_ram_setting = 61
  61. for i, r in enumerate(ramDICT[y_ram_setting]):
  62. ale.setRAM(i,r)
  63. ale.setRAM(25,x)
  64. ale.act(Action.NOOP)
  65. elif a == "s":
  66. y_ram_setting += 1
  67. if y_ram_setting >= 1950:
  68. y_ram_setting = 1945
  69. for i, r in enumerate(ramDICT[y_ram_setting]):
  70. ale.setRAM(i,r)
  71. ale.setRAM(25,x)
  72. ale.act(Action.NOOP)
  73. elif a == "a":
  74. x -= 1
  75. if x <= 0:
  76. x = 0
  77. ale.setRAM(25,x)
  78. ale.act(Action.NOOP)
  79. elif a == "d":
  80. x += 1
  81. if x >= 144:
  82. x = 144
  83. ale.setRAM(25,x)
  84. ale.act(Action.NOOP)
  85. elif a == "reset":
  86. ram_pos = input("Ram Position:")
  87. for i, r in enumerate(ramDICT[int(ram_pos)]):
  88. ale.setRAM(i,r)
  89. ale.act(Action.NOOP)
  90. # Apply an action and get the resulting reward
  91. elif a == "set_x":
  92. x = int(input("X:"))
  93. ale.setRAM(25, x)
  94. ale.act(Action.NOOP)
  95. elif a == "set_vel":
  96. vel = input("Velocity:")
  97. ale.setRAM(14, int(vel))
  98. ale.act(Action.NOOP)
  99. velocity_set = True
  100. else:
  101. reward = ale.act(a)
  102. ram = ale.getRAM()
  103. #if j % 2 == 0:
  104. # y_pixel = int(j*1/2) + 55
  105. # ramDICT[y_pixel] = ram
  106. # print(f"saving to {y_pixel:04}")
  107. # if y_pixel == 126 or y_pixel == 235:
  108. # input("")
  109. int_old_ram = list(map(int, oldram))
  110. int_ram = list(map(int, ram))
  111. difference = list()
  112. for o, r in zip(int_old_ram, int_ram):
  113. difference.append(r-o)
  114. oldram = deepcopy(ram)
  115. #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
  116. print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
  117. #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")
  118. #for i, r in enumerate(ram):
  119. # print('{:03}:{:02x} '.format(i,r), end="")
  120. # if i % 16 == 15: print("")
  121. #print("")
  122. #for i, r in enumerate(difference):
  123. # string = '{:02}:{:03} '.format(i%100,r)
  124. # if r != 0:
  125. # print(color(string, fg='red'), end="")
  126. # else:
  127. # print(string, end="")
  128. # if i % 16 == 15: print("")
  129. print("Episode %d ended with score: %d" % (episode, total_reward))
  130. input("")
  131. with open('all_positions_v2.pickle', 'wb') as handle:
  132. pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
  133. ale.reset_game()