You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
156 lines
4.5 KiB
156 lines
4.5 KiB
import sys
|
|
import operator
|
|
from os import listdir, system
|
|
from random import randrange
|
|
from ale_py import ALEInterface, SDL_SUPPORT, Action
|
|
from colors import *
|
|
from PIL import Image
|
|
from matplotlib import pyplot as plt
|
|
import cv2
|
|
import pickle
|
|
import queue
|
|
from dataclasses import dataclass, field
|
|
|
|
from enum import Enum
|
|
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
|
|
import readchar
|
|
|
|
from sample_factory.algo.utils.tensor_dict import TensorDict
|
|
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
|
|
|
|
import time
|
|
|
|
import os

# Machine-specific paths. The defaults are the original author's locations;
# set TEMPEST_BINARY / SKIING_ROM to run the script on another machine.
tempest_binary = os.environ.get(
    "TEMPEST_BINARY",
    "/home/spranger/projects/tempest-devel/ranking_release/bin/storm",
)
rom_file = os.environ.get(
    "SKIING_ROM",
    "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin",
)
|
|
|
|
|
|
def input_to_action(char):
    """Map one keypress to what the interactive loop should do.

    "0"/"1"/"2" yield the ALE actions NOOP/RIGHT/LEFT, "3"/"4"/"5" yield the
    command strings "reset"/"set_x"/"set_vel", and "w"/"a"/"s"/"d" are passed
    through unchanged. Any other input yields None.
    """
    # Movement keys are handed back verbatim; the caller interprets them.
    if char in ("w", "a", "s", "d"):
        return char

    # Keys that make the caller prompt for a value.
    for key, command in (("3", "reset"), ("4", "set_x"), ("5", "set_vel")):
        if char == key:
            return command

    # Keys that map to real emulator actions.
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    return None
|
|
|
|
|
|
# Create the emulator. Sound and the on-screen window are only switched on
# when ale_py reports SDL support.
ale = ALEInterface()
if SDL_SUPPORT:
    for setting in ("sound", "display_screen"):
        ale.setBool(setting, True)

# Load the ROM only after the settings above have been applied.
ale.loadROM(rom_file)
|
|
|
|
# Debug switches. The code behind them used to be commented out; flip a flag
# to re-enable it.
RECORD_POSITIONS = False  # store RAM every second step into ramDICT; save it at episode end
PRINT_RAM = False  # hex dump of the whole RAM after every step
PRINT_RAM_DIFF = False  # per-byte change since the previous step, non-zero ones in red

# RAM addresses poked by this script (labels as used by the prompts/prints below).
RAM_VELOCITY = 14
RAM_PLAYER_X = 25
# Other addresses labelled during exploration: 104/105/106 clock (m, s, ms),
# 107 score, 86-94 "y_0".."y_8".

# Bounds for the keyboard-driven position changes.
Y_MIN = 61  # "w" never goes below this ramDICT index
Y_WRAP_AT = 1950  # "s" reaching this index ...
Y_WRAP_TO = 1945  # ... jumps back to this one
# NOTE(review): unlike Y_MIN this is not a clamp - stepping past 1949 jumps back
# five positions. Confirm whether that is intended or whether it should pin at 1949.
X_MIN = 0
X_MAX = 144
BYTE_MAX = 255  # largest value a single RAM byte can hold


def _apply_snapshot(position):
    """Write the stored RAM snapshot ``ramDICT[position]`` into the emulator byte by byte."""
    for address, value in enumerate(ramDICT[position]):
        ale.setRAM(address, value)


def _prompt_int(prompt, low=None, high=None):
    """Ask for an integer on stdin.

    Returns the integer, or None (after printing why) when the reply is not an
    integer or lies outside the optional inclusive bounds ``low``/``high``.
    """
    reply = input(prompt)
    try:
        value = int(reply)
    except ValueError:
        print(f"Not an integer: {reply!r}")
        return None
    if (low is not None and value < low) or (high is not None and value > high):
        print(f"{value} is outside [{low}, {high}]")
        return None
    return value


def _print_ram(ram):
    """Print every RAM byte as index:hex, 16 per row."""
    for i, r in enumerate(ram):
        print('{:03}:{:02x} '.format(i, r), end="")
        if i % 16 == 15:
            print("")
    print("")


def _print_ram_diff(difference):
    """Print per-byte RAM changes, 16 per row, highlighting non-zero ones in red."""
    for i, r in enumerate(difference):
        string = '{:02}:{:03} '.format(i % 100, r)
        print(color(string, fg='red') if r != 0 else string, end="")
        if i % 16 == 15:
            print("")


with open('all_positions_v2.pickle', 'rb') as handle:
    # Maps a y position (int) to a full RAM snapshot that is replayed with setRAM.
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70

oldram = deepcopy(ale.getRAM())
velocity_set = False
for episode in range(10):
    total_reward = 0
    j = 0  # steps actually taken this episode; drives RECORD_POSITIONS
    while not ale.game_over():
        # Keep the velocity pinned at 0 until the user sets one with "5".
        if not velocity_set:
            ale.setRAM(RAM_VELOCITY, 0)

        key = readchar.readchar()
        if key in ("\x03", "\x04"):
            # Ctrl-C / Ctrl-D: some readchar versions return these as plain
            # characters instead of raising, so stop explicitly.
            raise KeyboardInterrupt
        if key == "":
            raise EOFError("stdin closed")
        a = input_to_action(key)
        if a is None:
            continue  # unmapped key: ignore it rather than pass None to ale.act()

        if a == "w" or a == "s":
            if a == "w":
                y_ram_setting = max(y_ram_setting - 1, Y_MIN)
            else:
                y_ram_setting += 1
                if y_ram_setting >= Y_WRAP_AT:
                    y_ram_setting = Y_WRAP_TO
            _apply_snapshot(y_ram_setting)
            # Re-apply the current x on top of the snapshot just loaded.
            ale.setRAM(RAM_PLAYER_X, x)
            total_reward += ale.act(Action.NOOP)
        elif a == "a" or a == "d":
            x = max(x - 1, X_MIN) if a == "a" else min(x + 1, X_MAX)
            ale.setRAM(RAM_PLAYER_X, x)
            total_reward += ale.act(Action.NOOP)
        elif a == "reset":
            position = _prompt_int("Ram Position:")
            if position is None:
                continue
            if position not in ramDICT:
                print(f"No stored RAM for position {position}")
                continue
            _apply_snapshot(position)
            total_reward += ale.act(Action.NOOP)
        elif a == "set_x":
            new_x = _prompt_int("X:", 0, BYTE_MAX)
            if new_x is None:
                continue
            x = new_x
            ale.setRAM(RAM_PLAYER_X, x)
            total_reward += ale.act(Action.NOOP)
        elif a == "set_vel":
            vel = _prompt_int("Velocity:", 0, BYTE_MAX)
            if vel is None:
                continue
            ale.setRAM(RAM_VELOCITY, vel)
            total_reward += ale.act(Action.NOOP)
            velocity_set = True
        else:
            # A real emulator action: apply it and collect the resulting reward.
            total_reward += ale.act(a)
        j += 1

        ram = ale.getRAM()

        if RECORD_POSITIONS and j % 2 == 0:
            y_pixel = j // 2 + 55
            ramDICT[y_pixel] = ram
            print(f"saving to {y_pixel:04}")
            if y_pixel == 126 or y_pixel == 235:
                input("")  # wait for Enter at these two positions

        print(f"player_x: {ram[RAM_PLAYER_X]},\tplayer_y: {y_ram_setting}")
        if PRINT_RAM:
            _print_ram(ram)
        if PRINT_RAM_DIFF:
            # Convert to Python ints so negative changes are represented even if
            # RAM values are unsigned bytes.
            _print_ram_diff([int(r) - int(o) for o, r in zip(oldram, ram)])
        oldram = deepcopy(ram)

    print("Episode %d ended with score: %d" % (episode, total_reward))
    input("")
    if RECORD_POSITIONS:
        # ramDICT only changes while recording, so only then rewrite the file.
        with open('all_positions_v2.pickle', 'wb') as handle:
            pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ale.reset_game()
|