import gymnasium as gym
import numpy as np
import torch

from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.rl_utils import prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size

from sf_examples.atari.train_atari import parse_atari_args, register_atari_components

class SampleFactoryNNQueryWrapper:
    def setup(self):
        register_atari_components()
        cfg = parse_atari_args()

        # TODO: derive these spaces from the actual env instead of hardcoding
        # the 4-frame 84x84 Atari observation space and a 3-action space.
        obs_space = gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)})
        action_space = gym.spaces.Discrete(3)
        actor_critic = create_actor_critic(cfg, obs_space, action_space)
        actor_critic.eval()

        device = torch.device("cpu")  # ("cpu" if cfg.device == "cpu" else "cuda")
        actor_critic.model_to_device(device)

        # Load the "best" checkpoint for policy 0.
        policy_id = 0  # cfg.policy_index
        # name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
        name_prefix = "best"
        checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
        checkpoint_dict = Learner.load_checkpoint(checkpoints, device)  # torch.load(...)
        actor_critic.load_state_dict(checkpoint_dict["model"])

        # Initial recurrent state for a single agent/environment.
        rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device)

        self.rnn_states = rnn_states
        self.actor_critic = actor_critic

    def __init__(self):
        self.setup()

    def query(self, obs):
        with torch.no_grad():
            normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
            policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)

            # The policy samples actions from its distribution by default;
            # for deterministic evaluation we take the argmax action instead.
            action_distribution = self.actor_critic.action_distribution()
            actions = argmax_actions(action_distribution)

            # Actions are expected to carry a batch dimension.
            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)

            # Carry the recurrent state over to the next query.
            self.rnn_states = policy_outputs["new_rnn_states"]

            return actions[0][0].item()
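

# Minimal usage sketch, assuming a trained Sample Factory Atari experiment with
# a best_* checkpoint on disk (the experiment/env come from the CLI arguments
# that parse_atari_args() reads). The zero tensor is only a shape-compatible
# placeholder matching the hardcoded (4, 84, 84) uint8 observation space; a
# real caller would pass stacked frames from the environment.
if __name__ == "__main__":
    wrapper = SampleFactoryNNQueryWrapper()
    obs = {"obs": torch.zeros((1, 4, 84, 84), dtype=torch.uint8)}
    action = wrapper.query(obs)  # greedy (argmax) action index
    print(f"greedy action: {action}")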