import gymnasium as gym
import numpy as np
import torch

from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.rl_utils import prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sf_examples.atari.train_atari import parse_atari_args, register_atari_components


class SampleFactoryNNQueryWrapper:
    def __init__(self):
        self.setup()

    def setup(self):
        register_atari_components()
        cfg = parse_atari_args()

        # TODO: derive the observation/action spaces from the environment instead of hardcoding them.
        actor_critic = create_actor_critic(
            cfg,
            gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)}),
            gym.spaces.Discrete(3),
        )
        actor_critic.eval()

        device = torch.device("cpu")  # inference on CPU; switch on cfg.device to use CUDA instead
        actor_critic.model_to_device(device)

        # Load the "best" checkpoint of policy 0
        # (cfg.policy_index / cfg.load_checkpoint_kind could make these configurable).
        policy_id = 0
        name_prefix = "best"
        checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
        checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
        actor_critic.load_state_dict(checkpoint_dict["model"])

        # Initial recurrent state for a single agent (batch size 1).
        self.rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device)
        self.actor_critic = actor_critic

    def query(self, obs):
        with torch.no_grad():
            normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
            policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)

            # Act greedily: take the argmax of the action distribution instead of sampling.
            action_distribution = self.actor_critic.action_distribution()
            actions = argmax_actions(action_distribution)

            if actions.ndim == 1:
                actions = unsqueeze_tensor(actions, dim=-1)

            # Carry the recurrent state forward so consecutive queries see a consistent history.
            self.rnn_states = policy_outputs["new_rnn_states"]

            return actions[0][0].item()
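

# Example usage: a minimal sketch, not part of the wrapper itself. The dummy
# observation below is a hypothetical stand-in that assumes the hardcoded
# (4, 84, 84) uint8 frame-stack observation space above; in practice you would
# pass real stacked frames from the Atari environment.
if __name__ == "__main__":
    wrapper = SampleFactoryNNQueryWrapper()
    dummy_obs = {"obs": torch.zeros((1, 4, 84, 84), dtype=torch.uint8)}
    action = wrapper.query(dummy_obs)
    print(f"greedy action: {action}")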