4 changed files with 170 additions and 8 deletions
			
			
		- 
					BIN  all_positions.pickle
- 
					156  manual_control.py
- 
					8  test.py
- 
					14  train.py
| @ -0,0 +1,156 @@ | |||
| import sys | |||
| import operator | |||
| from os import listdir, system | |||
| from random import randrange | |||
| from ale_py import ALEInterface, SDL_SUPPORT, Action | |||
| from colors import * | |||
| from PIL import Image | |||
| from matplotlib import pyplot as plt | |||
| import cv2 | |||
| import pickle | |||
| import queue | |||
| from dataclasses import dataclass, field | |||
| 
 | |||
| from enum import Enum | |||
| 
 | |||
| from copy import deepcopy | |||
| 
 | |||
| import numpy as np | |||
| 
 | |||
| import readchar | |||
| 
 | |||
| from sample_factory.algo.utils.tensor_dict import TensorDict | |||
| from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper | |||
| 
 | |||
| import time | |||
| 
 | |||
# Machine-specific absolute paths to external resources — adjust per host.
tempest_binary = "/home/spranger/projects/tempest-devel/ranking_release/bin/storm"  # tempest/Storm model-checker binary (unused in this script's visible code)
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"  # Atari Skiing ROM loaded into ALE below
| 
 | |||
| 
 | |||
def input_to_action(char):
    """Map one keyboard character to an ALE action or a control command.

    Parameters
    ----------
    char : str
        A single character read from the keyboard.

    Returns
    -------
    Action | str | None
        - "0"/"1"/"2" -> ALE ``Action.NOOP`` / ``Action.RIGHT`` / ``Action.LEFT``
        - "3"/"4"/"5" -> control-command strings "reset" / "set_x" / "set_vel"
        - "w"/"a"/"s"/"d" -> the character itself (RAM-teleport movement keys)
        - anything else -> ``None`` (callers must treat None as "ignore";
          passing it to ``ale.act`` would fail)
    """
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    if char == "3":
        return "reset"
    if char == "4":
        return "set_x"
    if char == "5":
        return "set_vel"
    if char in ["w", "a", "s", "d"]:
        return char
    # Fix: previously fell off the end, silently returning None for unmapped
    # keys. Make the fall-through explicit so the contract is visible to the
    # caller (the main loop's else-branch calls ale.act(a) with this value).
    return None
| 
 | |||
| 
 | |||
# --- Interactive ALE (Atari Skiing) RAM exploration tool --------------------
# Lets a human step the game one keypress at a time, teleport the skier by
# writing RAM directly, and re-save a pickle database of per-position RAM
# snapshots after each episode.
ale = ALEInterface()


if SDL_SUPPORT:
    ale.setBool("sound", True)
    ale.setBool("display_screen", True)

# Load the ROM file
ale.loadROM(rom_file)

# ramDICT: previously captured full-RAM snapshots keyed by vertical position.
# NOTE(review): pickle.load on a local file — only open trusted files.
with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
y_ram_setting = 60  # current vertical position (key into ramDICT)
x = 70              # current horizontal player position (written to RAM addr 25)

oldram = deepcopy(ale.getRAM())  # previous frame's RAM, for byte-wise diffing
velocity_set = False             # once True, stop forcing velocity (RAM 14) to 0
for episode in range(10):
    total_reward = 0
    j = 0
    while not ale.game_over():
        # Freeze the skier by zeroing the velocity byte until the user
        # explicitly sets one via the "5" (set_vel) command.
        if not velocity_set: ale.setRAM(14,0)
        j += 1
        # readchar blocks for one keypress; repr(...)[1] strips the quote
        # around the repr'd character to recover the raw key.
        a = input_to_action(repr(readchar.readchar())[1])
        #a = Action.NOOP

        if a == "w":
            # Move up one position (clamped at 61): restore the whole RAM
            # snapshot for that y, then re-apply the current x position.
            y_ram_setting -= 1
            if y_ram_setting <= 61:
                y_ram_setting = 61
            for i, r in enumerate(ramDICT[y_ram_setting]):
                ale.setRAM(i,r)
            ale.setRAM(25,x)
            ale.act(Action.NOOP)
        elif a == "s":
            # Move down one position (clamped: values >= 1950 snap to 1945).
            y_ram_setting += 1
            if y_ram_setting >= 1950:
                y_ram_setting = 1945
            for i, r in enumerate(ramDICT[y_ram_setting]):
                ale.setRAM(i,r)
            ale.setRAM(25,x)
            ale.act(Action.NOOP)
        elif a == "a":
            # Move left; x is clamped to >= 0 and written to RAM address 25.
            x -= 1
            if x <= 0:
                x = 0
            ale.setRAM(25,x)
            ale.act(Action.NOOP)
        elif a == "d":
            # Move right; x is clamped to <= 144.
            x += 1
            if x >= 144:
                x = 144
            ale.setRAM(25,x)
            ale.act(Action.NOOP)


        elif a == "reset":
            # Teleport: restore the RAM snapshot for a user-entered position.
            ram_pos = input("Ram Position:")
            for i, r in enumerate(ramDICT[int(ram_pos)]):
                ale.setRAM(i,r)
            ale.act(Action.NOOP)
        # Apply an action and get the resulting reward
        elif a == "set_x":
            x = int(input("X:"))
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "set_vel":
            vel = input("Velocity:")
            ale.setRAM(14, int(vel))
            ale.act(Action.NOOP)
            velocity_set = True
        else:
            # Genuine game action (NOOP / LEFT / RIGHT).
            # NOTE(review): unmapped keys yield a == None here and
            # ale.act(None) would fail — consider guarding.
            reward = ale.act(a)
        ram = ale.getRAM()
        #if j % 2 == 0:
        #    y_pixel = int(j*1/2) + 55
        #    ramDICT[y_pixel] = ram
        #    print(f"saving to {y_pixel:04}")
        #    if y_pixel == 126 or y_pixel == 235:
        #        input("")

        # Byte-wise RAM diff vs the previous frame — used when un-commenting
        # the printing below to reverse-engineer which addresses a key affects.
        int_old_ram = list(map(int, oldram))
        int_ram = list(map(int, ram))
        difference = list()
        for o, r in zip(int_old_ram, int_ram):
            difference.append(r-o)

        oldram = deepcopy(ram)
        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")

        #for i, r in enumerate(ram):
        #    print('{:03}:{:02x} '.format(i,r), end="")
        #    if i % 16 == 15: print("")
        #print("")
        #for i, r in enumerate(difference):
        #    string = '{:02}:{:03} '.format(i%100,r)
        #    if r != 0:
        #        print(color(string, fg='red'), end="")
        #    else:
        #        print(string, end="")
        #    if i % 16 == 15: print("")
    # NOTE(review): total_reward is never updated in the loop, so this always
    # prints 0 — rewards from ale.act() are assigned to `reward` but not summed.
    print("Episode %d ended with score: %d" % (episode, total_reward))
    input("")

    # Persist the (possibly updated) snapshot database after every episode.
    with open('all_positions_v2.pickle', 'wb') as handle:
        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ale.reset_game()
| @ -0,0 +1,14 @@ | |||
| from stable_baselines3 import PPO, DQN | |||
| from stable_baselines3.common.monitor import Monitor | |||
| from stable_baselines3.common.logger import configure, Image | |||
| from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CheckpointCallback | |||
| from gym_minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper, MiniWrapper | |||
| 
 | |||
| import os | |||
| from subprocess import call | |||
| import time | |||
| import argparse | |||
| import gym | |||
| 
 | |||
# Create the Atari Skiing environment with on-screen rendering.
env = gym.make("ALE/Skiing-v5", render_mode="human")
# NOTE(review): the two-value reset() return (obs, info) requires the
# gym >= 0.26 / gymnasium-style API — confirm the installed gym version.
observation, info = env.reset()
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue