sp
8 months ago
4 changed files with 170 additions and 8 deletions
-
- all_positions.pickle (binary)
- manual_control.py (156 lines changed)
- test.py (8 lines changed)
- train.py (14 lines changed)
@ -0,0 +1,156 @@ |
|||||
|
import sys |
||||
|
import operator |
||||
|
from os import listdir, system |
||||
|
from random import randrange |
||||
|
from ale_py import ALEInterface, SDL_SUPPORT, Action |
||||
|
from colors import * |
||||
|
from PIL import Image |
||||
|
from matplotlib import pyplot as plt |
||||
|
import cv2 |
||||
|
import pickle |
||||
|
import queue |
||||
|
from dataclasses import dataclass, field |
||||
|
|
||||
|
from enum import Enum |
||||
|
|
||||
|
from copy import deepcopy |
||||
|
|
||||
|
import numpy as np |
||||
|
|
||||
|
import readchar |
||||
|
|
||||
|
from sample_factory.algo.utils.tensor_dict import TensorDict |
||||
|
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper |
||||
|
|
||||
|
import time |
||||
|
|
||||
|
# Absolute, machine-specific paths; adjust them when running on another host.
# Path to the "storm" executable of a tempest-devel build.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"
# Atari Skiing ROM; this is the file handed to ALEInterface.loadROM().
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
||||
|
|
||||
|
|
||||
|
def input_to_action(char):
    """Translate one typed character into an emulator action or a command.

    Args:
        char: The single character read from the keyboard.

    Returns:
        * ``Action.NOOP`` / ``Action.RIGHT`` / ``Action.LEFT`` for "0" / "1" / "2".
        * The command strings ``"reset"``, ``"set_x"`` and ``"set_vel"`` for
          "3", "4" and "5".
        * The character itself for the keys "w", "a", "s" and "d".
        * ``None`` for every other input.
    """
    # Number keys that select a command instead of an emulator action.
    commands = {"3": "reset", "4": "set_x", "5": "set_vel"}
    if char in commands:
        return commands[char]

    # w/a/s/d are passed through unchanged; the caller interprets them.
    if char in ("w", "a", "s", "d"):
        return char

    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT

    # Unmapped key: keep the original fall-through result, but say so.
    return None
||||
|
|
||||
|
|
||||
|
def _restore_ram(snapshot):
    """Write every byte of a stored RAM snapshot into the running emulator."""
    for index, value in enumerate(snapshot):
        ale.setRAM(index, value)


ale = ALEInterface()

if SDL_SUPPORT:
    ale.setBool("sound", True)
    ale.setBool("display_screen", True)

# Load the ROM file
ale.loadROM(rom_file)

# ramDICT is indexed by an integer y position and yields a sequence of RAM
# byte values (one per RAM index) that can be written back with setRAM.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60
x = 70

oldram = deepcopy(ale.getRAM())
velocity_set = False
for episode in range(10):
    total_reward = 0
    j = 0
    while not ale.game_over():
        # Until the user chooses a velocity ("5"), RAM cell 14 is forced to 0
        # before every step.
        if not velocity_set:
            ale.setRAM(14, 0)
        j += 1
        a = input_to_action(repr(readchar.readchar())[1])
        #a = Action.NOOP

        if a == "w":
            # Step to the previous stored y position, never below 61.
            y_ram_setting -= 1
            if y_ram_setting <= 61:
                y_ram_setting = 61
            _restore_ram(ramDICT[y_ram_setting])
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "s":
            # Step to the next stored y position.
            # NOTE(review): reaching 1950 jumps back to 1945 rather than
            # clamping to 1949 -- kept as in the original; confirm intent.
            y_ram_setting += 1
            if y_ram_setting >= 1950:
                y_ram_setting = 1945
            _restore_ram(ramDICT[y_ram_setting])
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "a":
            # Move left; RAM cell 25 is the value printed as player_x below.
            x -= 1
            if x <= 0:
                x = 0
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "d":
            # Move right, capped at 144.
            x += 1
            if x >= 144:
                x = 144
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "reset":
            # Jump straight to a stored snapshot chosen by the user.
            ram_pos = input("Ram Position:")
            _restore_ram(ramDICT[int(ram_pos)])
            ale.act(Action.NOOP)
        elif a == "set_x":
            x = int(input("X:"))
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "set_vel":
            vel = input("Velocity:")
            ale.setRAM(14, int(vel))
            ale.act(Action.NOOP)
            velocity_set = True
        else:
            # Apply an action and get the resulting reward
            # NOTE(review): input_to_action returns None for unmapped keys and
            # that value is passed to ale.act() unchanged, as before.
            reward = ale.act(a)
            # Fix: the reward used to be dropped, so the per-episode score
            # printed below was always 0.
            total_reward += reward
        ram = ale.getRAM()
        #if j % 2 == 0:
        #    y_pixel = int(j*1/2) + 55
        #    ramDICT[y_pixel] = ram
        #    print(f"saving to {y_pixel:04}")
        #    if y_pixel == 126 or y_pixel == 235:
        #        input("")

        # Per-cell RAM change since the previous step; values are converted to
        # plain ints before subtracting. Only the commented-out debug output
        # below consumes it.
        difference = [int(r) - int(o) for o, r in zip(oldram, ram)]

        oldram = deepcopy(ram)
        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")

        #for i, r in enumerate(ram):
        #    print('{:03}:{:02x} '.format(i,r), end="")
        #    if i % 16 == 15: print("")
        #print("")
        #for i, r in enumerate(difference):
        #    string = '{:02}:{:03} '.format(i%100,r)
        #    if r != 0:
        #        print(color(string, fg='red'), end="")
        #    else:
        #        print(string, end="")
        #    if i % 16 == 15: print("")
    print("Episode %d ended with score: %d" % (episode, total_reward))
    input("")

    # NOTE(review): the original indentation was lost; saving and resetting
    # once per episode is assumed here -- confirm against the real file.
    with open('all_positions_v2.pickle', 'wb') as handle:
        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ale.reset_game()
@ -0,0 +1,14 @@ |
|||||
|
from stable_baselines3 import PPO, DQN |
||||
|
from stable_baselines3.common.monitor import Monitor |
||||
|
from stable_baselines3.common.logger import configure, Image |
||||
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback |
||||
|
from gym_minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper, MiniWrapper |
||||
|
|
||||
|
import os |
||||
|
from subprocess import call |
||||
|
import time |
||||
|
import argparse |
||||
|
import gym |
||||
|
|
||||
|
# Create the Skiing environment with on-screen rendering and reset it to
# obtain the first observation and the accompanying info value.
env = gym.make("ALE/Skiing-v5", render_mode="human")
observation, info = env.reset()
Write
Preview
Loading…
Cancel
Save
Reference in new issue