You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

69 lines
2.9 KiB

  1. import time
  2. from collections import deque
  3. from typing import Dict, Tuple
  4. import gymnasium as gym
  5. import numpy as np
  6. import torch
  7. from torch import Tensor
  8. from sample_factory.algo.learning.learner import Learner
  9. from sample_factory.algo.sampling.batched_sampling import preprocess_actions
  10. from sample_factory.algo.utils.action_distributions import argmax_actions
  11. from sample_factory.algo.utils.env_info import extract_env_info
  12. from sample_factory.algo.utils.make_env import make_env_func_batched
  13. from sample_factory.algo.utils.misc import ExperimentStatus
  14. from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
  15. from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
  16. from sample_factory.cfg.arguments import load_from_checkpoint
  17. from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
  18. from sample_factory.model.actor_critic import create_actor_critic
  19. from sample_factory.model.model_utils import get_rnn_size
  20. from sample_factory.utils.attr_dict import AttrDict
  21. from sample_factory.utils.typing import Config, StatusCode
  22. from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log
  23. from sf_examples.atari.train_atari import parse_atari_args, register_atari_components
  24. class SampleFactoryNNQueryWrapper:
  25. def setup(self):
  26. register_atari_components()
  27. cfg = parse_atari_args()
  28. actor_critic = create_actor_critic(cfg, gym.spaces.Dict({"obs": gym.spaces.Box(0, 255, (4, 84, 84), np.uint8)}), gym.spaces.Discrete(3)) # TODO
  29. actor_critic.eval()
  30. device = torch.device("cpu") # ("cpu" if cfg.device == "cpu" else "cuda")
  31. actor_critic.model_to_device(device)
  32. policy_id = 0 #cfg.policy_index
  33. #name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
  34. name_prefix = "best"
  35. checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
  36. checkpoint_dict = Learner.load_checkpoint(checkpoints, device) # torch.load(...)
  37. actor_critic.load_state_dict(checkpoint_dict["model"])
  38. rnn_states = torch.zeros([1, get_rnn_size(cfg)], dtype=torch.float32, device=device)
  39. self.rnn_states = rnn_states
  40. self.actor_critic = actor_critic
  41. def __init__(self):
  42. self.setup()
  43. def query(self, obs):
  44. with torch.no_grad():
  45. normalized_obs = prepare_and_normalize_obs(self.actor_critic, obs)
  46. policy_outputs = self.actor_critic(normalized_obs, self.rnn_states)
  47. # sample actions from the distribution by default
  48. actions = policy_outputs["actions"]
  49. action_distribution = self.actor_critic.action_distribution()
  50. actions = argmax_actions(action_distribution)
  51. if actions.ndim == 1:
  52. actions = unsqueeze_tensor(actions, dim=-1)
  53. rnn_states = policy_outputs["new_rnn_states"]
  54. return actions[0][0].item()