6 changed files with 385 additions and 0 deletions
			
			
		- 
					BINall_positions_v2.pickle
 - 
					118evaluate.py
 - 
					BINinit.png
 - 
					9install.sh
 - 
					69query_sample_factory_checkpoint.py
 - 
					189rom_evaluate.py
 
@ -0,0 +1,118 @@ | 
				
			|||
import time, re, sys, csv, os | 
				
			|||
import gym | 
				
			|||
from PIL import Image | 
				
			|||
from copy import deepcopy | 
				
			|||
from dataclasses import dataclass, field | 
				
			|||
import numpy as np | 
				
			|||
 | 
				
			|||
from matplotlib import pyplot as plt | 
				
			|||
import readchar | 
				
			|||
 | 
				
			|||
def string_to_action(action): | 
				
			|||
    if action == "left": | 
				
			|||
        return 2 | 
				
			|||
    if action == "right": | 
				
			|||
        return 1 | 
				
			|||
    if action == "noop": | 
				
			|||
        return 0 | 
				
			|||
    return 0 | 
				
			|||
 | 
				
			|||
scheduler_file = "x80_y128_pos8.sched" | 
				
			|||
def convert(tuples): | 
				
			|||
    return dict(tuples) | 
				
			|||
@dataclass(frozen=True) | 
				
			|||
class State: | 
				
			|||
    x: int | 
				
			|||
    y: int | 
				
			|||
    ski_position: int | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def parse_scheduler(scheduler_file): | 
				
			|||
    scheduler = dict() | 
				
			|||
    try: | 
				
			|||
        with open(scheduler_file, "r") as f: | 
				
			|||
            file_content = f.readlines() | 
				
			|||
        for line in file_content: | 
				
			|||
            if not "move=0" in line: continue | 
				
			|||
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line)) | 
				
			|||
            #print("stateMapping", stateMapping) | 
				
			|||
            choice = re.findall(r"{(left|right|noop)}", line) | 
				
			|||
            if choice: choice = choice[0] | 
				
			|||
            #print("choice", choice) | 
				
			|||
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"])) | 
				
			|||
            scheduler[state] = choice | 
				
			|||
        return scheduler | 
				
			|||
 | 
				
			|||
    except EnvironmentError: | 
				
			|||
        print("TODO file not available. Exiting.") | 
				
			|||
        sys.exit(1) | 
				
			|||
 | 
				
			|||
env = gym.make("ALE/Skiing-v5")#, render_mode="human") | 
				
			|||
#env = gym.wrappers.ResizeObservation(env, (84, 84)) | 
				
			|||
#env = gym.wrappers.GrayScaleObservation(env) | 
				
			|||
 | 
				
			|||
 | 
				
			|||
observation, info = env.reset() | 
				
			|||
y = 40 | 
				
			|||
standstillcounter = 0 | 
				
			|||
def update_y(y, ski_position): | 
				
			|||
    y_update = 0 | 
				
			|||
    global standstillcounter | 
				
			|||
    if ski_position in [6,7, 8,9]: | 
				
			|||
        standstillcounter = 0 | 
				
			|||
        y_update = 16 | 
				
			|||
    elif ski_position in [4,5, 10,11]: | 
				
			|||
        standstillcounter = 0 | 
				
			|||
        y_update = 12 | 
				
			|||
    elif ski_position in [2,3, 12,13]: | 
				
			|||
        standstillcounter = 0 | 
				
			|||
        y_update = 8 | 
				
			|||
    elif ski_position in [1, 14] and standstillcounter >= 5: | 
				
			|||
        if standstillcounter >= 8: | 
				
			|||
            print("!!!!!!!!!! no more x updates!!!!!!!!!!!") | 
				
			|||
        y_update = 0 | 
				
			|||
    elif ski_position in [1, 14]: | 
				
			|||
        y_update = 4 | 
				
			|||
 | 
				
			|||
    if ski_position in [1, 14]: | 
				
			|||
        standstillcounter += 1 | 
				
			|||
    return y_update | 
				
			|||
 | 
				
			|||
def update_ski_position(ski_position, action): | 
				
			|||
    if action == 0: | 
				
			|||
        return ski_position | 
				
			|||
    elif action == 1: | 
				
			|||
        return min(ski_position+1, 14) | 
				
			|||
    elif action == 2: | 
				
			|||
        return max(ski_position-1, 1) | 
				
			|||
 | 
				
			|||
approx_x_coordinate = 80 | 
				
			|||
ski_position = 8 | 
				
			|||
 | 
				
			|||
#scheduler = parse_scheduler(scheduler_file) | 
				
			|||
j = 0 | 
				
			|||
for _ in range(1000000): | 
				
			|||
    j += 1 | 
				
			|||
    #action = env.action_space.sample()  # agent policy that uses the observation and info | 
				
			|||
    #action = int(repr(readchar.readchar())[1]) | 
				
			|||
    #action = string_to_action(scheduler.get(State(approx_x_coordinate, y, ski_position), "noop")) | 
				
			|||
    action = 0 | 
				
			|||
    #ski_position = update_ski_position(ski_position, action) | 
				
			|||
    #y_update = update_y(y, ski_position) | 
				
			|||
    #y += y_update if y_update else 0 | 
				
			|||
 | 
				
			|||
    #old_x = deepcopy(approx_x_coordinate) | 
				
			|||
    #approx_x_coordinate = int(np.mean(np.where(observation[:,:,1] == 92)[1])) | 
				
			|||
    #print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}") | 
				
			|||
    observation, reward, terminated, truncated, info = env.step(action) | 
				
			|||
    if terminated or truncated: | 
				
			|||
        observation, info = env.reset() | 
				
			|||
        break | 
				
			|||
 | 
				
			|||
    img = Image.fromarray(observation) | 
				
			|||
    img.save(f"images/{j:05}.png") | 
				
			|||
    #observation, reward, terminated, truncated, info = env.step(0) | 
				
			|||
    #observation, reward, terminated, truncated, info = env.step(0) | 
				
			|||
    #observation, reward, terminated, truncated, info = env.step(0) | 
				
			|||
    #observation, reward, terminated, truncated, info = env.step(0) | 
				
			|||
env.close() | 
				
			|||
| 
		 After Width: 160 | Height: 210 | Size: 1.0 KiB  | 
@ -0,0 +1,9 @@ | 
				
			|||
#!/bin/bash | 
				
			|||
 | 
				
			|||
# aptitude dependencies | 
				
			|||
sudo apt install python3.8-venv python3-tk | 
				
			|||
python3 -m pip install --user virtualenv | 
				
			|||
python3 -m venv env | 
				
			|||
 | 
				
			|||
source env/bin/activate | 
				
			|||
which python3 | 
				
			|||
@ -0,0 +1,69 @@ | 
				
			|||
import time | 
				
			|||
from collections import deque | 
				
			|||
from typing import Dict, Tuple | 
				
			|||
 | 
				
			|||
import gymnasium as gym | 
				
			|||
import numpy as np | 
				
			|||
import torch | 
				
			|||
from torch import Tensor | 
				
			|||
 | 
				
			|||
from sample_factory.algo.learning.learner import Learner | 
				
			|||
from sample_factory.algo.sampling.batched_sampling import preprocess_actions | 
				
			|||
from sample_factory.algo.utils.action_distributions import argmax_actions | 
				
			|||
from sample_factory.algo.utils.env_info import extract_env_info | 
				
			|||
from sample_factory.algo.utils.make_env import make_env_func_batched | 
				
			|||
from sample_factory.algo.utils.misc import ExperimentStatus | 
				
			|||
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs | 
				
			|||
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor | 
				
			|||
from sample_factory.cfg.arguments import load_from_checkpoint | 
				
			|||
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf | 
				
			|||
from sample_factory.model.actor_critic import create_actor_critic | 
				
			|||
from sample_factory.model.model_utils import get_rnn_size | 
				
			|||
from sample_factory.utils.attr_dict import AttrDict | 
				
			|||
from sample_factory.utils.typing import Config, StatusCode | 
				
			|||
from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log | 
				
			|||
 | 
				
			|||
from sf_examples.atari.train_atari import parse_atari_args, register_atari_components | 
				
			|||
 | 
				
			|||
class SampleFactoryNNQueryWrapper: | 
				
			|||
    def setup(self): | 
				
			|||
        register_atari_components() | 
				
			|||
        cfg = parse_atari_args() | 
				
			|||
        actor_critic = create_actor_critic(cfg, gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)}), gym.spaces.Discrete(3)) # TODO | 
				
			|||
        actor_critic.eval() | 
				
			|||
 | 
				
			|||
        device = torch.device("cpu") # ("cpu" if cfg.device == "cpu" else "cuda") | 
				
			|||
        actor_critic.model_to_device(device) | 
				
			|||
 | 
				
			|||
        policy_id = 0 #cfg.policy_index | 
				
			|||
        #name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind] | 
				
			|||
        name_prefix = "best" | 
				
			|||
        checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*") | 
				
			|||
        checkpoint_dict = Learner.load_checkpoint(checkpoints, device) # torch.load(...) | 
				
			|||
        actor_critic.load_state_dict(checkpoint_dict["model"]) | 
				
			|||
 | 
				
			|||
        rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device) | 
				
			|||
 | 
				
			|||
        self.rnn_states = rnn_states | 
				
			|||
        self.actor_critic = actor_critic | 
				
			|||
 | 
				
			|||
    def __init__(self): | 
				
			|||
        self.setup() | 
				
			|||
 | 
				
			|||
 | 
				
			|||
    def query(self, obs): | 
				
			|||
        with torch.no_grad(): | 
				
			|||
            normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs) | 
				
			|||
            policy_outputs = self.actor_critic(normalized_obs, self.rnn_states) | 
				
			|||
 | 
				
			|||
            # sample actions from the distribution by default | 
				
			|||
            actions = policy_outputs["actions"] | 
				
			|||
 | 
				
			|||
            action_distribution = self.actor_critic.action_distribution() | 
				
			|||
            actions = argmax_actions(action_distribution) | 
				
			|||
 | 
				
			|||
            if actions.ndim == 1: | 
				
			|||
                actions = unsqueeze_tensor(actions, dim=-1) | 
				
			|||
 | 
				
			|||
            rnn_states = policy_outputs["new_rnn_states"] | 
				
			|||
            return actions[0][0].item() | 
				
			|||
@ -0,0 +1,189 @@ | 
				
			|||
import sys | 
				
			|||
from random import randrange | 
				
			|||
from ale_py import ALEInterface, SDL_SUPPORT, Action | 
				
			|||
from colors import * | 
				
			|||
from PIL import Image | 
				
			|||
from matplotlib import pyplot as plt | 
				
			|||
import cv2 | 
				
			|||
import pickle | 
				
			|||
import queue | 
				
			|||
 | 
				
			|||
from copy import deepcopy | 
				
			|||
 | 
				
			|||
import numpy as np | 
				
			|||
 | 
				
			|||
import readchar | 
				
			|||
 | 
				
			|||
from sample_factory.algo.utils.tensor_dict import TensorDict | 
				
			|||
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper | 
				
			|||
 | 
				
			|||
import time | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def input_to_action(char): | 
				
			|||
    if char == "0": | 
				
			|||
        return Action.NOOP | 
				
			|||
    if char == "1": | 
				
			|||
        return Action.RIGHT | 
				
			|||
    if char == "2": | 
				
			|||
        return Action.LEFT | 
				
			|||
    if char == "3": | 
				
			|||
        return "reset" | 
				
			|||
    if char == "4": | 
				
			|||
        return "set_x" | 
				
			|||
    if char == "5": | 
				
			|||
        return "set_vel" | 
				
			|||
    if char in ["w", "a", "s", "d"]: | 
				
			|||
        return char | 
				
			|||
 | 
				
			|||
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) } | 
				
			|||
 | 
				
			|||
def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=200): | 
				
			|||
    print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}") | 
				
			|||
    for i, r in enumerate(ramDICT[y]): | 
				
			|||
        ale.setRAM(i,r) | 
				
			|||
    ski_position_setting = ski_position_counter[ski_position] | 
				
			|||
    for i in range(0,ski_position_setting[1]): | 
				
			|||
        ale.act(ski_position_setting[0]) | 
				
			|||
        ale.setRAM(14,0) | 
				
			|||
        ale.setRAM(25,x) | 
				
			|||
    ale.setRAM(14,180) | 
				
			|||
 | 
				
			|||
    all_obs = list() | 
				
			|||
    for i in range(0,duration): | 
				
			|||
        resized_obs = cv2.resize(ale.getScreenGrayscale() , (84,84), interpolation=cv2.INTER_AREA) | 
				
			|||
        all_obs.append(resized_obs) | 
				
			|||
        if len(all_obs) >= 4: | 
				
			|||
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])}) | 
				
			|||
            action = nn_wrapper.query(stack_tensor) | 
				
			|||
            ale.act(input_to_action(str(action))) | 
				
			|||
        else: | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
        time.sleep(0.005) | 
				
			|||
 | 
				
			|||
ale = ALEInterface() | 
				
			|||
 | 
				
			|||
 | 
				
			|||
if SDL_SUPPORT: | 
				
			|||
    ale.setBool("sound", True) | 
				
			|||
    ale.setBool("display_screen", True) | 
				
			|||
 | 
				
			|||
# Load the ROM file | 
				
			|||
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin" | 
				
			|||
ale.loadROM(rom_file) | 
				
			|||
 | 
				
			|||
# Get the list of legal actions | 
				
			|||
 | 
				
			|||
with open('all_positions_v2.pickle', 'rb') as handle: | 
				
			|||
    ramDICT = pickle.load(handle) | 
				
			|||
#ramDICT = dict() | 
				
			|||
#for i,r in enumerate(ramDICT[235]): | 
				
			|||
#    ale.setRAM(i,r) | 
				
			|||
 | 
				
			|||
y_ram_setting = 60 | 
				
			|||
x = 70 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
nn_wrapper = SampleFactoryNNQueryWrapper() | 
				
			|||
#run_single_test(ale, nn_wrapper, 70,61,5) | 
				
			|||
#input("") | 
				
			|||
run_single_test(ale, nn_wrapper, 30,61,5,duration=1000) | 
				
			|||
run_single_test(ale, nn_wrapper, 114,170,7) | 
				
			|||
run_single_test(ale, nn_wrapper, 124,170,5) | 
				
			|||
run_single_test(ale, nn_wrapper, 134,170,2) | 
				
			|||
run_single_test(ale, nn_wrapper, 120,185,1) | 
				
			|||
run_single_test(ale, nn_wrapper, 134,170,8) | 
				
			|||
run_single_test(ale, nn_wrapper, 85,195,8) | 
				
			|||
velocity_set = False | 
				
			|||
for episode in range(10): | 
				
			|||
    total_reward = 0 | 
				
			|||
    j = 0 | 
				
			|||
    while not ale.game_over(): | 
				
			|||
        if not velocity_set: ale.setRAM(14,0) | 
				
			|||
        j += 1 | 
				
			|||
        a = input_to_action(repr(readchar.readchar())[1]) | 
				
			|||
        #a = Action.NOOP | 
				
			|||
 | 
				
			|||
        if a == "w": | 
				
			|||
            y_ram_setting -= 1 | 
				
			|||
            if y_ram_setting <= 61: | 
				
			|||
                y_ram_setting = 61 | 
				
			|||
            for i, r in enumerate(ramDICT[y_ram_setting]): | 
				
			|||
                ale.setRAM(i,r) | 
				
			|||
            ale.setRAM(25,x) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
        elif a == "s": | 
				
			|||
            y_ram_setting += 1 | 
				
			|||
            if y_ram_setting >= 1950: | 
				
			|||
                y_ram_setting = 1945 | 
				
			|||
            for i, r in enumerate(ramDICT[y_ram_setting]): | 
				
			|||
                ale.setRAM(i,r) | 
				
			|||
            ale.setRAM(25,x) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
        elif a == "a": | 
				
			|||
            x -= 1 | 
				
			|||
            if x <= 0: | 
				
			|||
                x = 0 | 
				
			|||
            ale.setRAM(25,x) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
        elif a == "d": | 
				
			|||
            x += 1 | 
				
			|||
            if x >= 144: | 
				
			|||
                x = 144 | 
				
			|||
            ale.setRAM(25,x) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
 | 
				
			|||
 | 
				
			|||
        elif a == "reset": | 
				
			|||
            ram_pos = input("Ram Position:") | 
				
			|||
            for i, r in enumerate(ramDICT[int(ram_pos)]): | 
				
			|||
                ale.setRAM(i,r) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
        # Apply an action and get the resulting reward | 
				
			|||
        elif a == "set_x": | 
				
			|||
            x = int(input("X:")) | 
				
			|||
            ale.setRAM(25, x) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
        elif a == "set_vel": | 
				
			|||
            vel = input("Velocity:") | 
				
			|||
            ale.setRAM(14, int(vel)) | 
				
			|||
            ale.act(Action.NOOP) | 
				
			|||
            velocity_set = True | 
				
			|||
        else: | 
				
			|||
            reward = ale.act(a) | 
				
			|||
        ram = ale.getRAM() | 
				
			|||
        #if j % 2 == 0: | 
				
			|||
        #    y_pixel = int(j*1/2) + 55 | 
				
			|||
        #    ramDICT[y_pixel] = ram | 
				
			|||
        #    print(f"saving to {y_pixel:04}") | 
				
			|||
        #    if y_pixel == 126 or y_pixel == 235: | 
				
			|||
        #        input("") | 
				
			|||
 | 
				
			|||
        int_old_ram = list(map(int, oldram)) | 
				
			|||
        int_ram = list(map(int, ram)) | 
				
			|||
        difference = list() | 
				
			|||
        for o, r in zip(int_old_ram, int_ram): | 
				
			|||
            difference.append(r-o) | 
				
			|||
 | 
				
			|||
        oldram = deepcopy(ram) | 
				
			|||
        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}") | 
				
			|||
        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}") | 
				
			|||
        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}") | 
				
			|||
 | 
				
			|||
        #for i, r in enumerate(ram): | 
				
			|||
        #    print('{:03}:{:02x} '.format(i,r), end="") | 
				
			|||
        #    if i % 16 == 15: print("") | 
				
			|||
        #print("") | 
				
			|||
        #for i, r in enumerate(difference): | 
				
			|||
        #    string = '{:02}:{:03} '.format(i%100,r) | 
				
			|||
        #    if r != 0: | 
				
			|||
        #        print(color(string, fg='red'), end="") | 
				
			|||
        #    else: | 
				
			|||
        #        print(string, end="") | 
				
			|||
        #    if i % 16 == 15: print("") | 
				
			|||
    print("Episode %d ended with score: %d" % (episode, total_reward)) | 
				
			|||
    input("") | 
				
			|||
 | 
				
			|||
    with open('all_positions_v2.pickle', 'wb') as handle: | 
				
			|||
        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL) | 
				
			|||
    ale.reset_game() | 
				
			|||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue