|
|
import sys from random import randrange from ale_py import ALEInterface, SDL_SUPPORT, Action from colors import * from PIL import Image from matplotlib import pyplot as plt import cv2 import pickle import queue
from copy import deepcopy
import numpy as np
import readchar
from sample_factory.algo.utils.tensor_dict import TensorDict from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
import time
def input_to_action(char):
    """Translate one typed character into an emulator action or a script command.

    Returns:
        * ``Action.NOOP`` / ``Action.RIGHT`` / ``Action.LEFT`` for "0" / "1" / "2",
        * the command strings "reset" / "set_x" / "set_vel" for "3" / "4" / "5",
        * the character itself for the keys "w", "a", "s" and "d",
        * ``None`` for every other input.
    """
    # Movement keys are passed through untouched; the caller interprets them.
    if char in ("w", "a", "s", "d"):
        return char

    # Digits that select a script command rather than an emulator action.
    for key, command in (("3", "reset"), ("4", "set_x"), ("5", "set_vel")):
        if char == key:
            return command

    # Digits that map directly onto emulator actions.
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT

    # Anything else is unmapped.
    return None
# Lookup used by run_single_test: ski position index -> (action, repeat count).
# The action is sent to the emulator `repeat count` times before a test run
# starts (presumably to rotate the skier into that position -- confirm).
ski_position_counter = {
    1: (Action.LEFT, 40),
    2: (Action.LEFT, 35),
    3: (Action.LEFT, 30),
    4: (Action.LEFT, 10),
    5: (Action.NOOP, 1),
    6: (Action.RIGHT, 10),
    7: (Action.RIGHT, 30),
    8: (Action.RIGHT, 40),
}
def run_single_test(ale, nn_wrapper, x, y, ski_position, duration=200):
    """Restore a recorded start state and let the network play from it.

    Args:
        ale: emulator interface; its RAM is overwritten and it is stepped
            ``duration`` times (plus the steps needed to set the ski position).
        nn_wrapper: object whose ``query`` method receives the stacked
            observation and returns the action id to play.
        x: value written to RAM cell 25 (the cell this script prints as
            ``player_x``).
        y: key into the module-level ``ramDICT`` selecting the RAM snapshot.
        ski_position: key into ``ski_position_counter``.
        duration: number of emulator steps to run after the setup.
    """
    print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}")

    # Restore the complete RAM snapshot recorded for this y position.
    for address, value in enumerate(ramDICT[y]):
        ale.setRAM(address, value)

    # Replay the turn action the configured number of times.
    # NOTE(review): the source formatting was lost; the two setRAM calls are
    # assumed to sit inside this loop (pinning RAM cell 14 -- the cell the
    # interactive "set_vel" command writes -- to 0 and x in place while
    # turning), with cell 14 set to 180 once afterwards. Confirm.
    turn_action, repetitions = ski_position_counter[ski_position]
    for _ in range(repetitions):
        ale.act(turn_action)
        ale.setRAM(14, 0)
        ale.setRAM(25, x)
    ale.setRAM(14, 180)

    # Only the four most recent frames are ever fed to the network, so the
    # buffer is capped at that size instead of growing with `duration`.
    stack_size = 4
    frame_stack = []
    for _ in range(duration):
        frame = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        frame_stack.append(frame)
        if len(frame_stack) > stack_size:
            frame_stack.pop(0)

        if len(frame_stack) == stack_size:
            stack_tensor = TensorDict({"obs": np.array(frame_stack)})
            action = nn_wrapper.query(stack_tensor)
            ale.act(input_to_action(str(action)))
        else:
            # Not enough frames for a full stack yet: idle.
            ale.act(Action.NOOP)
        # Short pause after every step (assumed per-step, see note above).
        time.sleep(0.005)
# Emulator instance shared by the scripted tests and the interactive loop.
ale = ALEInterface()

# Sound and the on-screen display are enabled only when ale_py reports SDL support.
if SDL_SUPPORT:
    for sdl_option in ("sound", "display_screen"):
        ale.setBool(sdl_option, True)

# Load the ROM file
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
ale.loadROM(rom_file)

# Recorded RAM snapshots, indexed by y position; each entry is written back
# into the emulator cell by cell when a position is restored.
# NOTE: pickle executes arbitrary code on load -- only use a trusted file here.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
# Disabled alternatives kept from the original script:
#ramDICT = dict()
#for i,r in enumerate(ramDICT[235]):
#    ale.setRAM(i,r)

# Start coordinates used by the interactive loop.
y_ram_setting, x = 60, 70
# Policy wrapper queried by run_single_test.
nn_wrapper = SampleFactoryNNQueryWrapper()

# Scripted roll-outs: (x, y snapshot key, ski position[, duration]).
#run_single_test(ale, nn_wrapper, 70,61,5)
#input("")
run_single_test(ale, nn_wrapper, 30, 61, 5, duration=1000)
run_single_test(ale, nn_wrapper, 114, 170, 7)
run_single_test(ale, nn_wrapper, 124, 170, 5)
run_single_test(ale, nn_wrapper, 134, 170, 2)
run_single_test(ale, nn_wrapper, 120, 185, 1)
run_single_test(ale, nn_wrapper, 134, 170, 8)
run_single_test(ale, nn_wrapper, 85, 195, 8)

# Interactive session: every key press drives one loop iteration.
#   w / s    step to the previous / next recorded y snapshot
#   a / d    move x by one
#   0 / 1 / 2   NOOP / RIGHT / LEFT emulator action
#   3 / 4 / 5   prompt for a snapshot key / x value / velocity
velocity_set = False

# Baseline for the per-step RAM diff below. Without it the first iteration
# read `oldram` before it had ever been assigned and raised NameError.
oldram = deepcopy(ale.getRAM())

for episode in range(10):
    total_reward = 0
    j = 0
    while not ale.game_over():
        # Until a velocity has been entered via "set_vel", keep RAM cell 14 at 0.
        if not velocity_set:
            ale.setRAM(14, 0)

        a = input_to_action(repr(readchar.readchar())[1])
        #a = Action.NOOP
        if a is None:
            # Unmapped key: there is nothing valid to hand to ale.act().
            continue
        j += 1

        if a == "w":
            y_ram_setting -= 1
            if y_ram_setting <= 61:
                y_ram_setting = 61
            for i, r in enumerate(ramDICT[y_ram_setting]):
                ale.setRAM(i, r)
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "s":
            y_ram_setting += 1
            # NOTE(review): reaching 1950 jumps back to 1945 rather than
            # clamping to 1949 -- kept as in the original; confirm intent.
            if y_ram_setting >= 1950:
                y_ram_setting = 1945
            for i, r in enumerate(ramDICT[y_ram_setting]):
                ale.setRAM(i, r)
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "a":
            x -= 1
            if x <= 0:
                x = 0
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "d":
            x += 1
            if x >= 144:
                x = 144
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "reset":
            ram_pos = input("Ram Position:")
            for i, r in enumerate(ramDICT[int(ram_pos)]):
                ale.setRAM(i, r)
            ale.act(Action.NOOP)
        elif a == "set_x":
            x = int(input("X:"))
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "set_vel":
            vel = input("Velocity:")
            ale.setRAM(14, int(vel))
            ale.act(Action.NOOP)
            velocity_set = True
        else:
            # Apply an action and get the resulting reward
            reward = ale.act(a)
            # Accumulate it so the end-of-episode score is not always 0.
            total_reward += reward

        ram = ale.getRAM()

        # Disabled recording mode: stores the current RAM in ramDICT every
        # second step.
        #if j % 2 == 0:
        #    y_pixel = int(j*1/2) + 55
        #    ramDICT[y_pixel] = ram
        #    print(f"saving to {y_pixel:04}")
        #    if y_pixel == 126 or y_pixel == 235:
        #        input("")

        # Per-cell change since the previous step, computed on plain ints
        # (presumably so the subtraction can go negative -- confirm). It is
        # only consumed by the disabled debug dump below.
        difference = [new - old for old, new in zip(map(int, oldram), map(int, ram))]
        oldram = deepcopy(ram)

        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")

        #for i, r in enumerate(ram):
        #    print('{:03}:{:02x} '.format(i,r), end="")
        #    if i % 16 == 15: print("")
        #print("")
        #for i, r in enumerate(difference):
        #    string = '{:02}:{:03} '.format(i%100,r)
        #    if r != 0:
        #        print(color(string, fg='red'), end="")
        #    else:
        #        print(string, end="")
        #    if i % 16 == 15: print("")

    print("Episode %d ended with score: %d" % (episode, total_reward))
    input("")

    # Write the snapshot table back to disk after every episode, then start over.
    with open('all_positions_v2.pickle', 'wb') as handle:
        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ale.reset_game()
|