Browse Source

checkpoint commit

main
sp 1 year ago
parent
commit
786bfae90b
  1. BIN
      all_positions_v2.pickle
  2. 118
      evaluate.py
  3. BIN
      init.png
  4. 9
      install.sh
  5. 69
      query_sample_factory_checkpoint.py
  6. 189
      rom_evaluate.py

BIN
all_positions_v2.pickle

118
evaluate.py

@ -0,0 +1,118 @@
import time, re, sys, csv, os
import gym
from PIL import Image
from copy import deepcopy
from dataclasses import dataclass, field
import numpy as np
from matplotlib import pyplot as plt
import readchar
def string_to_action(action):
if action == "left":
return 2
if action == "right":
return 1
if action == "noop":
return 0
return 0
scheduler_file = "x80_y128_pos8.sched"
def convert(tuples):
return dict(tuples)
@dataclass(frozen=True)
class State:
x: int
y: int
ski_position: int
def parse_scheduler(scheduler_file):
scheduler = dict()
try:
with open(scheduler_file, "r") as f:
file_content = f.readlines()
for line in file_content:
if not "move=0" in line: continue
stateMapping = convert(re.findall(r"([a-zA-Z_]*[a-zA-Z])=(\d+)?", line))
#print("stateMapping", stateMapping)
choice = re.findall(r"{(left|right|noop)}", line)
if choice: choice = choice[0]
#print("choice", choice)
state = State(int(stateMapping["x"]), int(stateMapping["y"]), int(stateMapping["ski_position"]))
scheduler[state] = choice
return scheduler
except EnvironmentError:
print("TODO file not available. Exiting.")
sys.exit(1)
env = gym.make("ALE/Skiing-v5")#, render_mode="human")
#env = gym.wrappers.ResizeObservation(env, (84, 84))
#env = gym.wrappers.GrayScaleObservation(env)
observation, info = env.reset()
y = 40
standstillcounter = 0
def update_y(y, ski_position):
y_update = 0
global standstillcounter
if ski_position in [6,7, 8,9]:
standstillcounter = 0
y_update = 16
elif ski_position in [4,5, 10,11]:
standstillcounter = 0
y_update = 12
elif ski_position in [2,3, 12,13]:
standstillcounter = 0
y_update = 8
elif ski_position in [1, 14] and standstillcounter >= 5:
if standstillcounter >= 8:
print("!!!!!!!!!! no more x updates!!!!!!!!!!!")
y_update = 0
elif ski_position in [1, 14]:
y_update = 4
if ski_position in [1, 14]:
standstillcounter += 1
return y_update
def update_ski_position(ski_position, action):
if action == 0:
return ski_position
elif action == 1:
return min(ski_position+1, 14)
elif action == 2:
return max(ski_position-1, 1)
approx_x_coordinate = 80
ski_position = 8
#scheduler = parse_scheduler(scheduler_file)
j = 0
for _ in range(1000000):
j += 1
#action = env.action_space.sample() # agent policy that uses the observation and info
#action = int(repr(readchar.readchar())[1])
#action = string_to_action(scheduler.get(State(approx_x_coordinate, y, ski_position), "noop"))
action = 0
#ski_position = update_ski_position(ski_position, action)
#y_update = update_y(y, ski_position)
#y += y_update if y_update else 0
#old_x = deepcopy(approx_x_coordinate)
#approx_x_coordinate = int(np.mean(np.where(observation[:,:,1] == 92)[1]))
#print(f"Action: {action},\tski position: {ski_position},\ty_update: {y_update},\ty: {y},\tx: {approx_x_coordinate},\tx_update:{approx_x_coordinate - old_x}")
observation, reward, terminated, truncated, info = env.step(action)
if terminated or truncated:
observation, info = env.reset()
break
img = Image.fromarray(observation)
img.save(f"images/{j:05}.png")
#observation, reward, terminated, truncated, info = env.step(0)
#observation, reward, terminated, truncated, info = env.step(0)
#observation, reward, terminated, truncated, info = env.step(0)
#observation, reward, terminated, truncated, info = env.step(0)
env.close()

BIN
init.png

After

Width: 160  |  Height: 210  |  Size: 1.0 KiB

9
install.sh

@ -0,0 +1,9 @@
#!/bin/bash
# aptitude dependencies
sudo apt install python3.8-venv python3-tk
python3 -m pip install --user virtualenv
python3 -m venv env
source env/bin/activate
which python3

69
query_sample_factory_checkpoint.py

@ -0,0 +1,69 @@
import time
from collections import deque
from typing import Dict, Tuple
import gymnasium as gym
import numpy as np
import torch
from torch import Tensor
from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log
from sf_examples.atari.train_atari import parse_atari_args, register_atari_components
class SampleFactoryNNQueryWrapper:
def setup(self):
register_atari_components()
cfg = parse_atari_args()
actor_critic = create_actor_critic(cfg, gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)}), gym.spaces.Discrete(3)) # TODO
actor_critic.eval()
device = torch.device("cpu") # ("cpu" if cfg.device == "cpu" else "cuda")
actor_critic.model_to_device(device)
policy_id = 0 #cfg.policy_index
#name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
name_prefix = "best"
checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
checkpoint_dict = Learner.load_checkpoint(checkpoints, device) # torch.load(...)
actor_critic.load_state_dict(checkpoint_dict["model"])
rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device)
self.rnn_states = rnn_states
self.actor_critic = actor_critic
def __init__(self):
self.setup()
def query(self, obs):
with torch.no_grad():
normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)
# sample actions from the distribution by default
actions = policy_outputs["actions"]
action_distribution = self.actor_critic.action_distribution()
actions = argmax_actions(action_distribution)
if actions.ndim == 1:
actions = unsqueeze_tensor(actions, dim=-1)
rnn_states = policy_outputs["new_rnn_states"]
return actions[0][0].item()

189
rom_evaluate.py

@ -0,0 +1,189 @@
import sys
from random import randrange
from ale_py import ALEInterface, SDL_SUPPORT, Action
from colors import *
from PIL import Image
from matplotlib import pyplot as plt
import cv2
import pickle
import queue
from copy import deepcopy
import numpy as np
import readchar
from sample_factory.algo.utils.tensor_dict import TensorDict
from query_sample_factory_checkpoint import SampleFactoryNNQueryWrapper
import time
def input_to_action(char):
if char == "0":
return Action.NOOP
if char == "1":
return Action.RIGHT
if char == "2":
return Action.LEFT
if char == "3":
return "reset"
if char == "4":
return "set_x"
if char == "5":
return "set_vel"
if char in ["w", "a", "s", "d"]:
return char
ski_position_counter = {1: (Action.LEFT, 40), 2: (Action.LEFT, 35), 3: (Action.LEFT, 30), 4: (Action.LEFT, 10), 5: (Action.NOOP, 1), 6: (Action.RIGHT, 10), 7: (Action.RIGHT, 30), 8: (Action.RIGHT, 40) }
def run_single_test(ale, nn_wrapper, x,y,ski_position, duration=200):
print(f"Running Test from x: {x:04}, y: {y:04}, ski_position: {ski_position}")
for i, r in enumerate(ramDICT[y]):
ale.setRAM(i,r)
ski_position_setting = ski_position_counter[ski_position]
for i in range(0,ski_position_setting[1]):
ale.act(ski_position_setting[0])
ale.setRAM(14,0)
ale.setRAM(25,x)
ale.setRAM(14,180)
all_obs = list()
for i in range(0,duration):
resized_obs = cv2.resize(ale.getScreenGrayscale() , (84,84), interpolation=cv2.INTER_AREA)
all_obs.append(resized_obs)
if len(all_obs) >= 4:
stack_tensor = TensorDict({"obs": np.array(all_obs[-4:])})
action = nn_wrapper.query(stack_tensor)
ale.act(input_to_action(str(action)))
else:
ale.act(Action.NOOP)
time.sleep(0.005)
ale = ALEInterface()
if SDL_SUPPORT:
ale.setBool("sound", True)
ale.setBool("display_screen", True)
# Load the ROM file
rom_file = "/home/spranger/research/Skiing/env/lib/python3.8/site-packages/AutoROM/roms/skiing.bin"
ale.loadROM(rom_file)
# Get the list of legal actions
with open('all_positions_v2.pickle', 'rb') as handle:
ramDICT = pickle.load(handle)
#ramDICT = dict()
#for i,r in enumerate(ramDICT[235]):
# ale.setRAM(i,r)
y_ram_setting = 60
x = 70
nn_wrapper = SampleFactoryNNQueryWrapper()
#run_single_test(ale, nn_wrapper, 70,61,5)
#input("")
run_single_test(ale, nn_wrapper, 30,61,5,duration=1000)
run_single_test(ale, nn_wrapper, 114,170,7)
run_single_test(ale, nn_wrapper, 124,170,5)
run_single_test(ale, nn_wrapper, 134,170,2)
run_single_test(ale, nn_wrapper, 120,185,1)
run_single_test(ale, nn_wrapper, 134,170,8)
run_single_test(ale, nn_wrapper, 85,195,8)
velocity_set = False
for episode in range(10):
total_reward = 0
j = 0
while not ale.game_over():
if not velocity_set: ale.setRAM(14,0)
j += 1
a = input_to_action(repr(readchar.readchar())[1])
#a = Action.NOOP
if a == "w":
y_ram_setting -= 1
if y_ram_setting <= 61:
y_ram_setting = 61
for i, r in enumerate(ramDICT[y_ram_setting]):
ale.setRAM(i,r)
ale.setRAM(25,x)
ale.act(Action.NOOP)
elif a == "s":
y_ram_setting += 1
if y_ram_setting >= 1950:
y_ram_setting = 1945
for i, r in enumerate(ramDICT[y_ram_setting]):
ale.setRAM(i,r)
ale.setRAM(25,x)
ale.act(Action.NOOP)
elif a == "a":
x -= 1
if x <= 0:
x = 0
ale.setRAM(25,x)
ale.act(Action.NOOP)
elif a == "d":
x += 1
if x >= 144:
x = 144
ale.setRAM(25,x)
ale.act(Action.NOOP)
elif a == "reset":
ram_pos = input("Ram Position:")
for i, r in enumerate(ramDICT[int(ram_pos)]):
ale.setRAM(i,r)
ale.act(Action.NOOP)
# Apply an action and get the resulting reward
elif a == "set_x":
x = int(input("X:"))
ale.setRAM(25, x)
ale.act(Action.NOOP)
elif a == "set_vel":
vel = input("Velocity:")
ale.setRAM(14, int(vel))
ale.act(Action.NOOP)
velocity_set = True
else:
reward = ale.act(a)
ram = ale.getRAM()
#if j % 2 == 0:
# y_pixel = int(j*1/2) + 55
# ramDICT[y_pixel] = ram
# print(f"saving to {y_pixel:04}")
# if y_pixel == 126 or y_pixel == 235:
# input("")
int_old_ram = list(map(int, oldram))
int_ram = list(map(int, ram))
difference = list()
for o, r in zip(int_old_ram, int_ram):
difference.append(r-o)
oldram = deepcopy(ram)
#print(f"player_x: {ram[25]},\tclock_m: {ram[104]},\tclock_s: {ram[105]},\tclock_ms: {ram[106]},\tscore: {ram[107]}")
print(f"player_x: {ram[25]},\tplayer_y: {y_ram_setting}")
#print(f"y_0: {ram[86]}, y_1: {ram[87]}, y_2: {ram[88]}, y_3: {ram[89]}, y_4: {ram[90]}, y_5: {ram[91]}, y_6: {ram[92]}, y_7: {ram[93]}, y_8: {ram[94]}")
#for i, r in enumerate(ram):
# print('{:03}:{:02x} '.format(i,r), end="")
# if i % 16 == 15: print("")
#print("")
#for i, r in enumerate(difference):
# string = '{:02}:{:03} '.format(i%100,r)
# if r != 0:
# print(color(string, fg='red'), end="")
# else:
# print(string, end="")
# if i % 16 == 15: print("")
print("Episode %d ended with score: %d" % (episode, total_reward))
input("")
with open('all_positions_v2.pickle', 'wb') as handle:
pickle.dump(ramDICT, handle, protocol=pickle.HIGHEST_PROTOCOL)
ale.reset_game()
Loading…
Cancel
Save