---
layout: "contents"
title: Training Minigrid Environments
firstpage:
---
# Training Minigrid Environments
Agents can be trained on the environments in the Minigrid library using [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/). In this tutorial we show how to train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment.
## Create a Custom Feature Extractor
Although `StableBaselines3` is fully compatible with `Gymnasium`-based environments (which include Minigrid), its default CNN feature extractor does not work with the Minigrid observation space: Minigrid's image observations are small grid encodings (7x7x3 for the default agent view), while the default `NatureCNN` extractor is designed for much larger, Atari-sized images. To train an agent on Minigrid environments, we therefore create a custom feature extractor, a class that inherits from `stable_baselines3.common.torch_layers.BaseFeaturesExtractor`:
```python
import gymnasium as gym
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class MinigridFeaturesExtractor(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False) -> None:
        # `normalized_image` is accepted but not used by this extractor
        super().__init__(observation_space, features_dim)
        # SB3 hands image observations to the extractor channel-first (C, H, W)
        n_input_channels = observation_space.shape[0]
        # Small 2x2 kernels fit Minigrid's compact grid observations
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 16, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(16, 32, (2, 2)),
            nn.ReLU(),
            nn.Conv2d(32, 64, (2, 2)),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute the flattened CNN output size by doing one forward pass on a sample
        with torch.no_grad():
            n_flatten = self.cnn(torch.as_tensor(observation_space.sample()[None]).float()).shape[1]

        # Project the CNN output down to `features_dim` features
        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear(self.cnn(observations))
```
This class is based on the StableBaselines3 custom feature extractor [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#custom-feature-extractor:~:text=Custom%20Feature%20Extractor-,%EF%83%81,-If%20you%20want), and the CNN architecture is copied from Lucas Willems' [rl-starter-files](https://github.com/lcswillems/rl-starter-files/blob/317da04a9a6fb26506bbd7f6c7c7e10fc0de86e0/model.py#L18).
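The dummy forward pass in `__init__` determines `n_flatten` automatically, but the same number can be worked out by hand with the standard convolution output-size formula, floor((size - kernel) / stride) + 1 (no padding or dilation). The sketch below assumes the default 7x7 agent view and uses the layer sizes of SB3's `NatureCNN` (8x8 with stride 4, then 4x4 with stride 2, then 3x3 with stride 1) for comparison; the helper names `conv2d_out` and `stack_out` are hypothetical.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a square Conv2d layer (no dilation).

    Returns 0 when the kernel is larger than the padded input, which is the
    case where PyTorch raises an error instead of producing an output.
    """
    padded = size + 2 * padding
    if kernel > padded:
        return 0
    return (padded - kernel) // stride + 1


def stack_out(size: int, layers: list[tuple[int, int]]) -> int:
    """Apply a sequence of (kernel, stride) convolutions to a square input."""
    for kernel, stride in layers:
        size = conv2d_out(size, kernel, stride)
        if size == 0:
            break  # this layer cannot be applied, so the stack fails here
    return size


# Assumption: Minigrid's default partial view, a 7x7 grid
VIEW = 7

# Extractor from this tutorial: three 2x2 convolutions with stride 1, ending in 64 channels
minigrid_side = stack_out(VIEW, [(2, 1), (2, 1), (2, 1)])
n_flatten = 64 * minigrid_side * minigrid_side

# Layer sizes of SB3's default NatureCNN
nature_side = stack_out(VIEW, [(8, 4), (4, 2), (3, 1)])

print(minigrid_side, n_flatten)  # 4 1024
print(nature_side)               # 0: the first 8x8 kernel does not fit a 7x7 grid
```

Each 2x2 convolution shrinks the grid by one cell per side (7, 6, 5, 4), so the flattened output has 64 * 4 * 4 = 1024 features, whereas the 8x8 kernel of `NatureCNN` is larger than the whole observation.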
## Train a PPO Agent
Using the custom feature extractor, we can train a PPO agent on the `MiniGrid-Empty-16x16-v0` environment, as shown in the following code snippet:
```python
import gymnasium as gym
import minigrid  # importing minigrid registers the MiniGrid-* environments
from minigrid.wrappers import ImgObsWrapper
from stable_baselines3 import PPO

# Use the custom feature extractor defined above, with 128 output features
policy_kwargs = dict(
    features_extractor_class=MinigridFeaturesExtractor,
    features_extractor_kwargs=dict(features_dim=128),
)

env = gym.make("MiniGrid-Empty-16x16-v0", render_mode="rgb_array")
# Keep only the image part of the dictionary observation
env = ImgObsWrapper(env)

model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs, verbose=1)
model.learn(int(2e5))  # total number of environment steps
```
By default, Minigrid observations are dictionaries. Since the `CnnPolicy` from StableBaselines3 expects image observations, we wrap the environment with the `ImgObsWrapper` from the Minigrid library, which replaces the dictionary observation with its `image` entry.
## Further Reading
StableBaselines3 policies can also take dictionary observations directly; for a walkthrough of how to do so, see [here](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations). An implementation that uses this functionality can be found [here](https://github.com/BolunDai0216/MinigridMiniworldTransfer/blob/main/minigrid_gotoobj_train.py).