sp · committed 1 year ago · 6 changed files with 385 additions and 0 deletions
- all_positions_v2.pickle (BIN)
- evaluate.py (+118)
- init.png (BIN)
- install.sh (+9)
- query_sample_factory_checkpoint.py (+69)
- rom_evaluate.py (+189)
evaluate.py
@@ -0,0 +1,118 @@
import time, re, sys, csv, os
import gym
from PIL import Image
from copy import deepcopy
from dataclasses import dataclass, field
import numpy as np

from matplotlib import pyplot as plt
import readchar


def string_to_action(action):
    if action == "left":
        return 2
    if action == "right":
        return 1
    if action == "noop":
        return 0
    return 0


scheduler_file = "x80_y128_pos8.sched"

def convert(tuples):
    return dict(tuples)

@dataclass(frozen=True)
class State:
    x: int
    y: int
    ski_position: int


def parse_scheduler(scheduler_file):
    scheduler = dict()
    try:
        with open(scheduler_file, "r") as f:
            file_content = f.readlines()
        for line in file_content:
            if not "move=0" in line: continue
            stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
            #print("stateMapping", stateMapping)
            choice = re.findall(r"{(left|right|noop)}", line)
            if choice: choice = choice[0]
            #print("choice", choice)
            state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
            scheduler[state] = choice
        return scheduler

    except EnvironmentError:
        print("TODO file not available. Exiting.")
        sys.exit(1)


env = gym.make("ALE/Skiing-v5")#, render_mode="human")
#env = gym.wrappers.ResizeObservation(env, (84, 84))
#env = gym.wrappers.GrayScaleObservation(env)


observation, info = env.reset()
y = 40
standstillcounter = 0

def update_y(y, ski_position):
    y_update = 0
    global standstillcounter
    if ski_position in [6,7, 8,9]:
        standstillcounter = 0
        y_update = 16
    elif ski_position in [4,5, 10,11]:
        standstillcounter = 0
        y_update = 12
    elif ski_position in [2,3, 12,13]:
        standstillcounter = 0
        y_update = 8
    elif ski_position in [1, 14] and standstillcounter >= 5:
        if standstillcounter >= 8:
            print("!!!!!!!!!! no more x updates!!!!!!!!!!!")
        y_update = 0
    elif ski_position in [1, 14]:
        y_update = 4

    if ski_position in [1, 14]:
        standstillcounter += 1
    return y_update


def update_ski_position(ski_position, action):
    if action == 0:
        return ski_position
    elif action == 1:
        return min(ski_position+1, 14)
    elif action == 2:
        return max(ski_position-1, 1)


approx_x_coordinate = 80
ski_position = 8

#scheduler = parse_scheduler(scheduler_file)
j = 0
for _ in range(1000000):
    j += 1
    #action = env.action_space.sample()  # agent policy that uses the observation and info
    #action = int(repr(readchar.readchar())[1])
    #action = string_to_action(scheduler.get(State(approx_x_coordinate, y, ski_position), "noop"))
    action = 0
    #ski_position = update_ski_position(ski_position, action)
    #y_update = update_y(y, ski_position)
    #y += y_update if y_update else 0

    #old_x = deepcopy(approx_x_coordinate)
    #approx_x_coordinate = int(np.mean(np.where(observation[:,:,1] == 92)[1]))
    #print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}")
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        observation, info = env.reset()
        break

    img = Image.fromarray(observation)
    img.save(f"images/{j:05}.png")
    #observation, reward, terminated, truncated, info = env.step(0)
    #observation, reward, terminated, truncated, info = env.step(0)
    #observation, reward, terminated, truncated, info = env.step(0)
    #observation, reward, terminated, truncated, info = env.step(0)
env.close()
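Note: the loop above currently drives the environment with a fixed NOOP (action = 0) and leaves the scheduler-driven path commented out. For orientation, a minimal sketch of that variant is given below. It reuses parse_scheduler, State, string_to_action, update_ski_position and update_y from this file and assumes a scheduler file in the "x=... y=... ski_position=... move=0 ... {left|right|noop}" format that parse_scheduler expects; the x-coordinate estimate from green-channel value 92 is the same heuristic as in the commented code above, not something additional.

    # Sketch only: scheduler-driven control loop, assuming the helpers defined in evaluate.py are in scope.
    env = gym.make("ALE/Skiing-v5")
    observation, info = env.reset()
    scheduler = parse_scheduler(scheduler_file)
    approx_x_coordinate, y, ski_position = 80, 40, 8

    for _ in range(1000):
        # Look up the scheduler's choice for the current abstract state; fall back to "noop".
        choice = scheduler.get(State(approx_x_coordinate, y, ski_position), "noop")
        action = string_to_action(choice)
        ski_position = update_ski_position(ski_position, action)
        y += update_y(y, ski_position)
        # Same pixel heuristic as above: mean column of green-channel pixels with value 92.
        approx_x_coordinate = int(np.mean(np.where(observation[:, :, 1] == 92)[1]))
        observation, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            break
    env.close()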
init.png (binary image added) — Width: 160 px | Height: 210 px | Size: 1.0 KiB
install.sh
@@ -0,0 +1,9 @@
#!/bin/bash

# aptitude dependencies
sudo apt install python3.8-venv python3-tk
python3 -m pip install --user virtualenv
python3 -m venv env

source env/bin/activate
which python3
query_sample_factory_checkpoint.py
@@ -0,0 +1,69 @@
import time
from collections import deque
from typing import Dict, Tuple

import gymnasium as gym
import numpy as np
import torch
from torch import Tensor

from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log

from sf_examples.atari.train_atari import parse_atari_args, register_atari_components


class SampleFactoryNNQueryWrapper:
    def setup(self):
        register_atari_components()
        cfg = parse_atari_args()
        actor_critic = create_actor_critic(cfg, gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)}), gym.spaces.Discrete(3)) # TODO
        actor_critic.eval()

        device = torch.device("cpu") # ("cpu" if cfg.device == "cpu" else "cuda")
        actor_critic.model_to_device(device)

        policy_id = 0 #cfg.policy_index
        #name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
        name_prefix = "best"
        checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
        checkpoint_dict = Learner.load_checkpoint(checkpoints, device) # torch.load(...)
        actor_critic.load_state_dict(checkpoint_dict["model"])

        rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device)

        self.rnn_states = rnn_states
        self.actor_critic = actor_critic

    def __init__(self):
        self.setup()

    def query(self, obs):
        with torch.no_grad():
            normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
            policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)

            # sample actions from the distribution by default
            actions = policy_outputs["actions"]

            action_distribution = self.actor_critic.action_distribution()
            actions = argmax_actions(action_distribution)

            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)

            rnn_states = policy_outputs["new_rnn_states"]
            return actions[0][0].item()
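Note: rom_evaluate.py below queries this wrapper with a TensorDict holding the four most recent 84x84 grayscale frames, matching the (4, 84, 84) observation space hard-coded in setup(). A minimal usage sketch along those lines (it assumes a trained Sample Factory "best_*" checkpoint exists under the experiment directory that parse_atari_args resolves; the zero frame is a placeholder for a real observation):

    # Sketch only: rolling 4-frame stack queried against the wrapper, as rom_evaluate.py does.
    import numpy as np
    from sample_factory.algo.utils.tensor_dict import TensorDict
    from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper

    nn_wrapper = SampleFactoryNNQueryWrapper()  # loads the "best_*" checkpoint on CPU

    frames = []
    for step in range(10):
        frame = np.zeros((84, 84), dtype=np.uint8)  # placeholder; use a real 84x84 grayscale frame
        frames.append(frame)
        if len(frames) >= 4:
            obs = TensorDict({"obs": np.array(frames[-4:])})  # shape (4, 84, 84)
            action = nn_wrapper.query(obs)  # int in range(3); rom_evaluate.py maps 0/1/2 to NOOP/RIGHT/LEFT
            print(step, action)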
rom_evaluate.py
@@ -0,0 +1,189 @@
import sys
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
from colors import *
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue

from copy import deepcopy

import numpy as np

import readchar

from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper

import time


def input_to_action(char):
    if char == "0":
        return Action.NOOP
    if char == "1":
        return Action.RIGHT
    if char == "2":
        return Action.LEFT
    if char == "3":
        return "reset"
    if char == "4":
        return "set_x"
    if char == "5":
        return "set_vel"
    if char in ["w", "a", "s", "d"]:
        return char


ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40)}

def run_single_test(ale, nn_wrapper, x, y, ski_position, duration=200):
    print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}")
    for i, r in enumerate(ramDICT[y]):
        ale.setRAM(i, r)
    ski_position_setting = ski_position_counter[ski_position]
    for i in range(0, ski_position_setting[1]):
        ale.act(ski_position_setting[0])
    ale.setRAM(14, 0)
    ale.setRAM(25, x)
    ale.setRAM(14, 180)

    all_obs = list()
    for i in range(0, duration):
        resized_obs = cv2.resize(ale.getScreenGrayscale(), (84, 84), interpolation=cv2.INTER_AREA)
        all_obs.append(resized_obs)
        if len(all_obs) >= 4:
            stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
            action = nn_wrapper.query(stack_tensor)
            ale.act(input_to_action(str(action)))
        else:
            ale.act(Action.NOOP)
        time.sleep(0.005)


ale = ALEInterface()

if SDL_SUPPORT:
    ale.setBool("sound", True)
    ale.setBool("display_screen", True)

# Load the ROM file
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
ale.loadROM(rom_file)

# Get the list of legal actions

with open('all_positions_v2.pickle', 'rb') as handle:
    ramDICT = pickle.load(handle)
#ramDICT = dict()
#for i,r in enumerate(ramDICT[235]):
#    ale.setRAM(i,r)

y_ram_setting = 60
x = 70


nn_wrapper = SampleFactoryNNQueryWrapper()
#run_single_test(ale, nn_wrapper, 70,61,5)
#input("")
run_single_test(ale, nn_wrapper, 30,61,5, duration=1000)
run_single_test(ale, nn_wrapper, 114,170,7)
run_single_test(ale, nn_wrapper, 124,170,5)
run_single_test(ale, nn_wrapper, 134,170,2)
run_single_test(ale, nn_wrapper, 120,185,1)
run_single_test(ale, nn_wrapper, 134,170,8)
run_single_test(ale, nn_wrapper, 85,195,8)

velocity_set = False
oldram = ale.getRAM()  # previous-frame RAM snapshot; initialized here so the first diff below does not fail
for episode in range(10):
    total_reward = 0
    j = 0
    while not ale.game_over():
        if not velocity_set: ale.setRAM(14, 0)
        j += 1
        a = input_to_action(repr(readchar.readchar())[1])
        #a = Action.NOOP

        if a == "w":
            y_ram_setting -= 1
            if y_ram_setting <= 61:
                y_ram_setting = 61
            for i, r in enumerate(ramDICT[y_ram_setting]):
                ale.setRAM(i, r)
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "s":
            y_ram_setting += 1
            if y_ram_setting >= 1950:
                y_ram_setting = 1945
            for i, r in enumerate(ramDICT[y_ram_setting]):
                ale.setRAM(i, r)
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "a":
            x -= 1
            if x <= 0:
                x = 0
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "d":
            x += 1
            if x >= 144:
                x = 144
            ale.setRAM(25, x)
            ale.act(Action.NOOP)

        elif a == "reset":
            ram_pos = input("Ram Position:")
            for i, r in enumerate(ramDICT[int(ram_pos)]):
                ale.setRAM(i, r)
            ale.act(Action.NOOP)
        # Apply an action and get the resulting reward
        elif a == "set_x":
            x = int(input("X:"))
            ale.setRAM(25, x)
            ale.act(Action.NOOP)
        elif a == "set_vel":
            vel = input("Velocity:")
            ale.setRAM(14, int(vel))
            ale.act(Action.NOOP)
            velocity_set = True
        else:
            reward = ale.act(a)
            total_reward += reward  # accumulate the score printed at the end of the episode
        ram = ale.getRAM()
        #if j % 2 == 0:
        #    y_pixel = int(j*1/2) + 55
        #    ramDICT[y_pixel] = ram
        #    print(f"saving to {y_pixel:04}")
        #    if y_pixel == 126 or y_pixel == 235:
        #        input("")

        int_old_ram = list(map(int, oldram))
        int_ram = list(map(int, ram))
        difference = list()
        for o, r in zip(int_old_ram, int_ram):
            difference.append(r - o)

        oldram = deepcopy(ram)
        #print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
        print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
        #print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")

        #for i, r in enumerate(ram):
        #    print('{:03}:{:02x} '.format(i,r), end="")
        #    if i % 16 == 15: print("")
        #print("")
        #for i, r in enumerate(difference):
        #    string = '{:02}:{:03} '.format(i%100,r)
        #    if r != 0:
        #        print(color(string, fg='red'), end="")
        #    else:
        #        print(string, end="")
        #    if i % 16 == 15: print("")
    print("Episode %d ended with score: %d" % (episode, total_reward))
    input("")

    with open('all_positions_v2.pickle', 'wb') as handle:
        pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
    ale.reset_game()
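Note: for readability, the RAM locations this script pokes and reads are summarized below. The names are inferred from how the script itself uses them (they are not taken from official documentation), so treat them as this commit's working assumptions.

    # Inferred meaning of the Skiing RAM addresses used above.
    SKIING_RAM_ADDRESSES = {
        14: "downhill velocity: set to 0 to hold the skier in place, 180 before a test run, or via 'set_vel'",
        25: "player x position (clamped to 0..144): written by 'a'/'d'/'set_x' and printed as player_x",
    }
    # Whole-RAM snapshots for a given vertical position are kept in ramDICT (all_positions_v2.pickle),
    # keyed by the y value tracked in y_ram_setting, and replayed with ale.setRAM(i, r).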