diff --git a/all_positions.pickle b/all_positions.pickle
new file mode 100644
index 0000000..6f8fee3
Binary files /dev/null and b/all_positions.pickle differ
diff --git a/manual_control.py b/manual_control.py
new file mode 100644
index 0000000..384accd
--- /dev/null
+++ b/manual_control.py
@@ -0,0 +1,156 @@
+import sys
+import operator
+from os import listdir, system
+from random import randrange
+from ale_py import ALEInterface, SDL_SUPPORT, Action
+from colors import *
+from PIL import Image
+from matplotlib import pyplot as plt
+import cv2
+import pickle
+import queue
+from dataclasses import dataclass, field
+
+from enum import Enum
+
+from copy import deepcopy
+
+import numpy as np
+
+import readchar
+
+from sample_factory.algo.utils.tensor_dict import TensorDict
+from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
+
+import time
+
+tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
+rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
+
+
+def input_to_action(char):
+    if char == "0":
+        return Action.NOOP
+    if char == "1":
+        return Action.RIGHT
+    if char == "2":
+        return Action.LEFT
+    if char == "3":
+        return "reset"
+    if char == "4":
+        return "set_x"
+    if char == "5":
+        return "set_vel"
+    if char in ["w", "a", "s", "d"]:
+        return char
+
+
+ale = ALEInterface()
+
+
+if SDL_SUPPORT:
+    ale.setBool("sound", True)
+    ale.setBool("display_screen", True)
+
+# Load the ROM file
+ale.loadROM(rom_file)
+
+with open('all_positions_v2.pickle', 'rb') as handle:
+    ramDICT = pickle.load(handle)  # maps y position -> full 128-byte RAM snapshot
+y_ram_setting = 60  # vertical position, used as key into ramDICT
+x = 70  # player x position, mirrored to RAM cell 25
+
+oldram = deepcopy(ale.getRAM())
+velocity_set = False
+for episode in range(10):
+    total_reward = 0
+    j = 0
+    while not ale.game_over():
+        if not velocity_set: ale.setRAM(14,0)  # hold velocity (RAM cell 14) at 0 until set via "5"
+        j += 1
+        a = input_to_action(repr(readchar.readchar())[1])
+        #a = Action.NOOP
+
+        if a == "w":
+            y_ram_setting -= 1
+            if y_ram_setting <= 61:
+                y_ram_setting = 61
+            for i, r in enumerate(ramDICT[y_ram_setting]):
+                ale.setRAM(i,r)
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+        elif a == "s":
+            y_ram_setting += 1
+            if y_ram_setting >= 1950:
+                y_ram_setting = 1945
+            for i, r in enumerate(ramDICT[y_ram_setting]):
+                ale.setRAM(i,r)
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+        elif a == "a":
+            x -= 1
+            if x <= 0:
+                x = 0
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+        elif a == "d":
+            x += 1
+            if x >= 144:
+                x = 144
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+
+
+        elif a == "reset":
+            ram_pos = input("Ram Position:")
+            for i, r in enumerate(ramDICT[int(ram_pos)]):
+                ale.setRAM(i,r)
+            ale.act(Action.NOOP)
+        elif a == "set_x":
+            x = int(input("X:"))
+            ale.setRAM(25, x)
+            ale.act(Action.NOOP)
+        elif a == "set_vel":
+            vel = input("Velocity:")
+            ale.setRAM(14, int(vel))
+            ale.act(Action.NOOP)
+            velocity_set = True
+        # Apply an action and get the resulting reward
+        else:
+            total_reward += ale.act(a)
+        ram = ale.getRAM()
+        #if j % 2 == 0:
+        #    y_pixel = int(j*1/2) + 55
+        #    ramDICT[y_pixel] = ram
+        #    print(f"saving to {y_pixel:04}")
+        #    if y_pixel == 126 or y_pixel == 235:
+        #        input("")
+
+        int_old_ram = list(map(int, oldram))
+        int_ram = list(map(int, ram))
+        difference = list()
+        for o, r in zip(int_old_ram, int_ram):
+            difference.append(r-o)
+
+        oldram = deepcopy(ram)
+        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
+        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
+        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")
+
+        #for i, r in enumerate(ram):
+        #    print('{:03}:{:02x} '.format(i,r), end="")
+        #    if i % 16 == 15: print("")
+        #print("")
+        #for i, r in enumerate(difference):
+        #    string = '{:02}:{:03} '.format(i%100,r)
+        #    if r != 0:
+        #        print(color(string, fg='red'), end="")
+        #    else:
+        #        print(string, end="")
+        #    if i % 16 == 15: print("")
+    print("Episode %d ended with score: %d" % (episode, total_reward))
+    input("")
+
+    with open('all_positions_v2.pickle', 'wb') as handle:
+        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
+    ale.reset_game()
diff --git a/test.py b/test.py
index 6401f41..e1c236b 100644
--- a/test.py
+++ b/test.py
@@ -6,9 +6,6 @@ import numpy as np
 from matplotlib import pyplot as plt
 import readchar
 
-import queue
-
-ski_position_queue = queue.Queue()
 
 env = gym.make("ALE/Skiing-v5", render_mode="human")
 
@@ -70,9 +67,4 @@ for _ in range(1000000):
     observation, reward, terminated, truncated, info = env.step(0)
     observation, reward, terminated, truncated, info = env.step(0)
 
-    #plt.imshow(observation)
-    #plt.show()
-    #im = Image.fromarray(observation)
-    #im.save("init.png")
-
 env.close()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..8021c12
--- /dev/null
+++ b/train.py
@@ -0,0 +1,14 @@
+from stable_baselines3 import PPO, DQN
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.logger import configure, Image
+from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback
+from gym_minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper, MiniWrapper
+
+import os
+from subprocess import call
+import time
+import argparse
+import gym
+
+env = gym.make("ALE/Skiing-v5", render_mode="human")
+observation, info = env.reset()
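
Note (not part of the diff): the commented-out recorder inside manual_control.py (`ramDICT[y_pixel] = ram` every second frame) hints at how all_positions_v2.pickle is produced. The sketch below reconstructs that recording loop under those assumptions; the `record_positions` helper and the ROM path are placeholders, not code from this commit.

import pickle
from ale_py import ALEInterface, Action

def record_positions(rom_file, out_file="all_positions_v2.pickle"):
    # Let the skier descend under NOOP and keep one full 128-byte RAM
    # snapshot per screen row, mirroring the commented-out block above.
    ale = ALEInterface()
    ale.loadROM(rom_file)
    ram_by_y = {}
    j = 0
    while not ale.game_over():
        ale.act(Action.NOOP)
        j += 1
        if j % 2 == 0:                 # skier advances roughly one pixel every two frames
            y_pixel = j // 2 + 55      # same +55 offset as the commented-out recorder
            ram_by_y[y_pixel] = ale.getRAM().copy()
    with open(out_file, "wb") as handle:
        pickle.dump(ram_by_y, handle, protocol=pickle.HIGHEST_PROTOCOL)

record_positions("roms/skiing.bin")  # assumed ROM location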