diff --git a/all_positions_v2.pickle b/all_positions_v2.pickle
new file mode 100644
index 0000000..79d3f0c
Binary files /dev/null and b/all_positions_v2.pickle differ
diff --git a/evaluate.py b/evaluate.py
new file mode 100644
index 0000000..3d83a5e
--- /dev/null
+++ b/evaluate.py
@@ -0,0 +1,118 @@
+import time, re, sys, csv, os
+import gym
+from PIL import Image
+from copy import deepcopy
+from dataclasses import dataclass, field
+import numpy as np
+
+from matplotlib import pyplot as plt
+import readchar
+
+def string_to_action(action):
+    if action == "left":
+        return 2
+    if action == "right":
+        return 1
+    if action == "noop":
+        return 0
+    return 0
+
+scheduler_file = "x80_y128_pos8.sched"
+def convert(tuples):
+    return dict(tuples)
+@dataclass(frozen=True)
+class State:
+    x: int
+    y: int
+    ski_position: int
+
+
+def parse_scheduler(scheduler_file):
+    scheduler = dict()
+    try:
+        with open(scheduler_file, "r") as f:
+            file_content = f.readlines()
+        for line in file_content:
+            if not "move=0" in line: continue
+            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
+            #print("stateMapping", stateMapping)
+            choice = re.findall(r"{(left|right|noop)}", line)
+            if choice: choice = choice[0]
+            #print("choice", choice)
+            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
+            scheduler[state] = choice
+        return scheduler
+
+    except EnvironmentError:
+        print(f"Scheduler file {scheduler_file} not available. Exiting.")
+        sys.exit(1)
+
+env = gym.make("ALE/Skiing-v5")#, render_mode="human")
+#env = gym.wrappers.ResizeObservation(env, (84, 84))
+#env = gym.wrappers.GrayScaleObservation(env)
+
+
+observation, info = env.reset()
+y = 40
+standstillcounter = 0
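+# NOTE: ski_position appears to run from 1 (skis fully left) to 14 (fully right),
+# with 7/8 pointing straight downhill; 1 and 14 correspond to a standstill stance.
+# The per-step descent values below (16/12/8/4) and the standstill thresholds look
+# like hand-tuned estimates of the game's behaviour, not values read from the emulator.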
+def update_y(y, ski_position):
+    y_update = 0
+    global standstillcounter
+    if ski_position in [6,7, 8,9]:
+        standstillcounter = 0
+        y_update = 16
+    elif ski_position in [4,5, 10,11]:
+        standstillcounter = 0
+        y_update = 12
+    elif ski_position in [2,3, 12,13]:
+        standstillcounter = 0
+        y_update = 8
+    elif ski_position in [1, 14] and standstillcounter >= 5:
+        if standstillcounter >= 8:
+            print("!!!!!!!!!! no more x updates!!!!!!!!!!!")
+            y_update = 0
+    elif ski_position in [1, 14]:
+        y_update = 4
+
+    if ski_position in [1, 14]:
+        standstillcounter += 1
+    return y_update
+
+def update_ski_position(ski_position, action):
+    if action == 0:
+        return ski_position
+    elif action == 1:
+        return min(ski_position+1, 14)
+    elif action == 2:
+        return max(ski_position-1, 1)
+
+approx_x_coordinate = 80
+ski_position = 8
+
+#scheduler = parse_scheduler(scheduler_file)
+j = 0
+os.makedirs("images", exist_ok=True) # frames are saved below; make sure the output directory exists
+for _ in range(1000000):
+    j += 1
+    #action = env.action_space.sample() # agent policy that uses the observation and info
+    #action = int(repr(readchar.readchar())[1])
+    #action = string_to_action(scheduler.get(State(approx_x_coordinate, y, ski_position), "noop"))
+    action = 0
+    #ski_position = update_ski_position(ski_position, action)
+    #y_update = update_y(y, ski_position)
+    #y += y_update if y_update else 0
+
+    #old_x = deepcopy(approx_x_coordinate)
+    #approx_x_coordinate = int(np.mean(np.where(observation[:,:,1] == 92)[1]))
+    #print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}")
+    observation, reward, terminated, truncated, info = env.step(action)
+    if terminated or truncated:
+        observation, info = env.reset()
+        break
+
+    img = Image.fromarray(observation)
+    img.save(f"images/{j:05}.png")
+    #observation, reward, terminated, truncated, info = env.step(0)
+    #observation, reward, terminated, truncated, info = env.step(0)
+    #observation, reward, terminated, truncated, info = env.step(0)
+    #observation, reward, terminated, truncated, info = env.step(0)
+env.close()
diff --git a/init.png b/init.png
new file mode 100644
index 0000000..1f8ebde
Binary files /dev/null and b/init.png differ
diff --git a/install.sh b/install.sh
new file mode 100755
index 0000000..100a2f3
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+# aptitude dependencies
+sudo apt install python3.8-venv python3-tk
+python3 -m pip install --user virtualenv
+python3 -m venv env
+
+source env/bin/activate
+which python3
diff --git a/query_sample_factory_checkpoint.py b/query_sample_factory_checkpoint.py
new file mode 100644
index 0000000..97d9520
--- /dev/null
+++ b/query_sample_factory_checkpoint.py
@@ -0,0 +1,69 @@
+import time
+from collections import deque
+from typing import Dict, Tuple
+
+import gymnasium as gym
+import numpy as np
+import torch
+from torch import Tensor
+
+from sample_factory.algo.learning.learner import Learner
+from sample_factory.algo.sampling.batched_sampling import preprocess_actions
+from sample_factory.algo.utils.action_distributions import argmax_actions
+from sample_factory.algo.utils.env_info import extract_env_info
+from sample_factory.algo.utils.make_env import make_env_func_batched
+from sample_factory.algo.utils.misc import ExperimentStatus
+from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
+from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
+from sample_factory.cfg.arguments import load_from_checkpoint
+from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
+from sample_factory.model.actor_critic import create_actor_critic
+from sample_factory.model.model_utils import get_rnn_size
+from sample_factory.utils.attr_dict import AttrDict
+from sample_factory.utils.typing import Config, StatusCode
+from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log
+
+from sf_examples.atari.train_atari import parse_atari_args, register_atari_components
+
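+# Thin wrapper around a trained sample-factory actor-critic for single-observation queries.
+# The observation space (a stack of four 84x84 grayscale frames) and the Discrete(3) action
+# space are hard-coded to match the tensors built in rom_evaluate.py; the "best_*" checkpoint
+# of the experiment selected via parse_atari_args() is loaded onto the CPU, and query()
+# returns the greedy (argmax) action rather than sampling from the policy distribution.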
+class SampleFactoryNNQueryWrapper:
+    def setup(self):
+        register_atari_components()
+        cfg = parse_atari_args()
+        actor_critic = create_actor_critic(cfg, gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)}), gym.spaces.Discrete(3)) # TODO
+        actor_critic.eval()
+
+        device = torch.device("cpu") # ("cpu" if cfg.device == "cpu" else "cuda")
+        actor_critic.model_to_device(device)
+
+        policy_id = 0 #cfg.policy_index
+        #name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
+        name_prefix = "best"
+        checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
+        checkpoint_dict = Learner.load_checkpoint(checkpoints, device) # torch.load(...)
+        actor_critic.load_state_dict(checkpoint_dict["model"])
+
+        rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device)
+
+        self.rnn_states = rnn_states
+        self.actor_critic = actor_critic
+
+    def __init__(self):
+        self.setup()
+
+
+    def query(self, obs):
+        with torch.no_grad():
+            normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
+            policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)
+
+            # sample actions from the distribution by default
+            actions = policy_outputs["actions"]
+
+            action_distribution = self.actor_critic.action_distribution()
+            actions = argmax_actions(action_distribution)
+
+            if actions.ndim == 1:
+                actions = unsqueeze_tensor(actions, dim=-1)
+
+            rnn_states = policy_outputs["new_rnn_states"]
+        return actions[0][0].item()
diff --git a/rom_evaluate.py b/rom_evaluate.py
new file mode 100644
index 0000000..bb4b420
--- /dev/null
+++ b/rom_evaluate.py
@@ -0,0 +1,189 @@
+import sys
+from random import randrange
+from ale_py import ALEInterface, SDL_SUPPORT, Action
+from colors import *
+from PIL import Image
+from matplotlib import pyplot as plt
+import cv2
+import pickle
+import queue
+
+from copy import deepcopy
+
+import numpy as np
+
+import readchar
+
+from sample_factory.algo.utils.tensor_dict import TensorDict
+from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
+
+import time
+
+
+def input_to_action(char):
+    if char == "0":
+        return Action.NOOP
+    if char == "1":
+        return Action.RIGHT
+    if char == "2":
+        return Action.LEFT
+    if char == "3":
+        return "reset"
+    if char == "4":
+        return "set_x"
+    if char == "5":
+        return "set_vel"
+    if char in ["w", "a", "s", "d"]:
+        return char
+
+ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
+
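+# run_single_test teleports the skier into a known state before handing control to the
+# policy: ramDICT maps a y position to a full 128-byte Atari RAM snapshot that is written
+# back into the emulator, RAM cell 25 appears to hold the player's x coordinate, and cell 14
+# appears to control the downhill speed (0 freezes the skier while the stance is set up,
+# 180 restores it). ski_position_counter gives the (action, repetitions) pair used to rotate
+# the skis into each of the 8 coarse stances.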
+def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=200):
+    print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}")
+    for i, r in enumerate(ramDICT[y]):
+        ale.setRAM(i,r)
+    ski_position_setting = ski_position_counter[ski_position]
+    for i in range(0,ski_position_setting[1]):
+        ale.act(ski_position_setting[0])
+    ale.setRAM(14,0)
+    ale.setRAM(25,x)
+    ale.setRAM(14,180)
+
+    all_obs = list()
+    for i in range(0,duration):
+        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84,84), interpolation=cv2.INTER_AREA)
+        all_obs.append(resized_obs)
+        if len(all_obs) >= 4:
+            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
+            action = nn_wrapper.query(stack_tensor)
+            ale.act(input_to_action(str(action)))
+        else:
+            ale.act(Action.NOOP)
+        time.sleep(0.005)
+
+ale = ALEInterface()
+
+
+if SDL_SUPPORT:
+    ale.setBool("sound", True)
+    ale.setBool("display_screen", True)
+
+# Load the ROM file
+rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
+ale.loadROM(rom_file)
+
+# Get the list of legal actions
+
+with open('all_positions_v2.pickle', 'rb') as handle:
+    ramDICT = pickle.load(handle)
+#ramDICT = dict()
+#for i,r in enumerate(ramDICT[235]):
+#    ale.setRAM(i,r)
+
+y_ram_setting = 60
+x = 70
+
+
+nn_wrapper = SampleFactoryNNQueryWrapper()
+#run_single_test(ale, nn_wrapper, 70,61,5)
+#input("")
+run_single_test(ale, nn_wrapper, 30,61,5,duration=1000)
+run_single_test(ale, nn_wrapper, 114,170,7)
+run_single_test(ale, nn_wrapper, 124,170,5)
+run_single_test(ale, nn_wrapper, 134,170,2)
+run_single_test(ale, nn_wrapper, 120,185,1)
+run_single_test(ale, nn_wrapper, 134,170,8)
+run_single_test(ale, nn_wrapper, 85,195,8)
+velocity_set = False
+oldram = deepcopy(ale.getRAM()) # baseline for the RAM diff computed inside the loop below
+for episode in range(10):
+    total_reward = 0
+    j = 0
+    while not ale.game_over():
+        if not velocity_set: ale.setRAM(14,0)
+        j += 1
+        a = input_to_action(repr(readchar.readchar())[1])
+        #a = Action.NOOP
+
+        if a == "w":
+            y_ram_setting -= 1
+            if y_ram_setting <= 61:
+                y_ram_setting = 61
+            for i, r in enumerate(ramDICT[y_ram_setting]):
+                ale.setRAM(i,r)
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+        elif a == "s":
+            y_ram_setting += 1
+            if y_ram_setting >= 1950:
+                y_ram_setting = 1945
+            for i, r in enumerate(ramDICT[y_ram_setting]):
+                ale.setRAM(i,r)
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+        elif a == "a":
+            x -= 1
+            if x <= 0:
+                x = 0
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+        elif a == "d":
+            x += 1
+            if x >= 144:
+                x = 144
+            ale.setRAM(25,x)
+            ale.act(Action.NOOP)
+
+
+        elif a == "reset":
+            ram_pos = input("Ram Position:")
+            for i, r in enumerate(ramDICT[int(ram_pos)]):
+                ale.setRAM(i,r)
+            ale.act(Action.NOOP)
+        # Apply an action and get the resulting reward
+        elif a == "set_x":
+            x = int(input("X:"))
+            ale.setRAM(25, x)
+            ale.act(Action.NOOP)
+        elif a == "set_vel":
+            vel = input("Velocity:")
+            ale.setRAM(14, int(vel))
+            ale.act(Action.NOOP)
+            velocity_set = True
+        else:
+            reward = ale.act(a)
+            total_reward += reward # accumulate the score reported at the end of the episode
+        ram = ale.getRAM()
+        #if j % 2 == 0:
+        #    y_pixel = int(j*1/2) + 55
+        #    ramDICT[y_pixel] = ram
+        #    print(f"saving to {y_pixel:04}")
+        #    if y_pixel == 126 or y_pixel == 235:
+        #        input("")
+
+        int_old_ram = list(map(int, oldram))
+        int_ram = list(map(int, ram))
+        difference = list()
+        for o, r in zip(int_old_ram, int_ram):
+            difference.append(r-o)
+
+        oldram = deepcopy(ram)
+        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
+        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
+        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")
+
+        #for i, r in enumerate(ram):
+        #    print('{:03}:{:02x} '.format(i,r), end="")
+        #    if i % 16 == 15: print("")
+        #print("")
+        #for i, r in enumerate(difference):
+        #    string = '{:02}:{:03} '.format(i%100,r)
+        #    if r != 0:
+        #        print(color(string, fg='red'), end="")
+        #    else:
+        #        print(string, end="")
+        #    if i % 16 == 15: print("")
+    print("Episode %d ended with score: %d" % (episode, total_reward))
+    input("")
+
+    with open('all_positions_v2.pickle', 'wb') as handle:
+        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
+    ale.reset_game()