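"""Wrappers for MiniGrid environments.

A collection of Gymnasium ``Wrapper`` subclasses that modify the observations,
actions, rewards, or reset behaviour of a wrapped MiniGrid environment.
"""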
from __future__ import annotations

import math
import operator
from functools import reduce
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal


class ReseedWrapper(Wrapper):
    """
    Wrapper to always regenerate an environment with the same set of seeds.
    This can be used to force an environment to always keep the same
    configuration when reset.

    Example:
        >>> import minigrid
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ReseedWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _ = env.reset(seed=123)
        >>> [env.np_random.integers(10) for i in range(10)]
        [0, 6, 5, 0, 9, 2, 2, 1, 3, 1]
        >>> env = ReseedWrapper(env, seeds=[0, 1], seed_idx=0)
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [8, 6, 5, 2, 3, 0, 0, 0, 1, 8]
        >>> _, _ = env.reset()
        >>> [env.np_random.integers(10) for i in range(10)]
        [4, 5, 7, 9, 0, 1, 8, 9, 2, 3]
    """

    def __init__(self, env, seeds=(0,), seed_idx=0):
        """A wrapper that always regenerates an environment with the same set of seeds.

        Args:
            env: The environment to apply the wrapper
            seeds: A list of seeds to be applied to the env
            seed_idx: Index of the initial seed in seeds
        """
        self.seeds = list(seeds)
        self.seed_idx = seed_idx
        super().__init__(env)

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        if seed is not None:
            logger.warn(
                "A seed has been passed to `ReseedWrapper.reset` which is ignored."
            )
        # Cycle through the configured seeds, wrapping around at the end
        seed = self.seeds[self.seed_idx]
        self.seed_idx = (self.seed_idx + 1) % len(self.seeds)
        return self.env.reset(seed=seed, options=options)


class ActionBonus(gym.Wrapper):
    """
    Wrapper which adds an exploration bonus.
    This is a reward to encourage exploration of less
    visited (state, action) pairs.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ActionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = ActionBonus(env)
        >>> _, _ = env_bonus.reset(seed=0)
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> _, reward, _, _, _ = env_bonus.step(1)
        >>> print(reward)
        1.0
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited (state, action) pairs.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        env = self.unwrapped
        tup = (tuple(env.agent_pos), env.agent_dir, action)

        # Get the count for this (s,a) pair
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this (s,a) pair
        new_count = pre_count + 1
        self.counts[tup] = new_count

        # The bonus decays as 1 / sqrt(visit count)
        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info


class PositionBonus(Wrapper):
    """
    Adds an exploration bonus based on which positions
    are visited on the grid.

    Note:
        This wrapper was previously called ``StateBonus``.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import PositionBonus
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> _, _ = env.reset(seed=0)
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> _, reward, _, _, _ = env.step(1)
        >>> print(reward)
        0
        >>> env_bonus = PositionBonus(env)
        >>> obs, _ = env_bonus.reset(seed=0)
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        1.0
        >>> obs, reward, terminated, truncated, info = env_bonus.step(1)
        >>> print(reward)
        0.7071067811865475
    """

    def __init__(self, env):
        """A wrapper that adds an exploration bonus to less visited positions.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.counts = {}

    def step(self, action):
        """Steps through the environment with `action`."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Tuple based on which we index the counts
        # We use the position after an update
        env = self.unwrapped
        tup = tuple(env.agent_pos)

        # Get the count for this key
        pre_count = 0
        if tup in self.counts:
            pre_count = self.counts[tup]

        # Update the count for this key
        new_count = pre_count + 1
        self.counts[tup] = new_count

        bonus = 1 / math.sqrt(new_count)
        reward += bonus

        return obs, reward, terminated, truncated, info


class ImgObsWrapper(ObservationWrapper):
    """
    Use the image as the only observation output, with no language/mission string.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs.keys()
        dict_keys(['image', 'direction', 'mission'])
        >>> env = ImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs.shape
        (7, 7, 3)
    """

    def __init__(self, env):
        """A wrapper that makes the image the only observation.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        self.observation_space = env.observation_space.spaces["image"]

    def observation(self, obs):
        return obs["image"]


class OneHotPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to get a one-hot encoding of a partially observable
    agent view as observation.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import OneHotPartialObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0],
               [2, 5, 0]], dtype=uint8)
        >>> env = OneHotPartialObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> obs["image"][0, :, :]
        array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
               [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0]],
              dtype=uint8)
    """

    def __init__(self, env, tile_size=8):
        """A wrapper that makes the image observation a one-hot encoding of a partially observable agent view.

        Args:
            env: The environment to apply the wrapper
            tile_size: The size (in pixels) of a grid tile
        """
        super().__init__(env)

        self.tile_size = tile_size

        obs_shape = env.observation_space["image"].shape

        # Number of bits per cell: one-hot segments for object type, color and state
        num_bits = len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + len(STATE_TO_IDX)

        new_image_space = spaces.Box(
            low=0, high=255, shape=(obs_shape[0], obs_shape[1], num_bits), dtype="uint8"
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        img = obs["image"]
        out = np.zeros(self.observation_space.spaces["image"].shape, dtype="uint8")

        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                # Named `obj_type` to avoid shadowing the builtin `type`
                obj_type = img[i, j, 0]
                color = img[i, j, 1]
                state = img[i, j, 2]

                out[i, j, obj_type] = 1
                out[i, j, len(OBJECT_TO_IDX) + color] = 1
                out[i, j, len(OBJECT_TO_IDX) + len(COLOR_TO_IDX) + state] = 1

        return {**obs, "image": out}


class RGBImgObsWrapper(ObservationWrapper):
    """
    Wrapper to use a fully observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])  # doctest: +SKIP
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env = RGBImgObsWrapper(env)
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs['image'])  # doctest: +SKIP
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
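        """A wrapper that renders the full grid as an RGB image observation.

        Args:
            env: The environment to apply the wrapper
            tile_size: The size (in pixels) used to render each grid tile
        """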
        super().__init__(env)

        self.tile_size = tile_size

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(
                self.unwrapped.width * tile_size,
                self.unwrapped.height * tile_size,
                3,
            ),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img = self.get_frame(
            highlight=self.unwrapped.highlight, tile_size=self.tile_size
        )

        return {**obs, "image": rgb_img}


class RGBImgPartialObsWrapper(ObservationWrapper):
    """
    Wrapper to use a partially observable RGB image as observation.
    This can be used to have the agent solve the gridworld in pixel space.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import RGBImgObsWrapper, RGBImgPartialObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![NoWrapper](../figures/lavacrossing_NoWrapper.png)
        >>> env_obs = RGBImgObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![RGBImgObsWrapper](../figures/lavacrossing_RGBImgObsWrapper.png)
        >>> env_obs = RGBImgPartialObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> plt.imshow(obs["image"])  # doctest: +SKIP
        ![RGBImgPartialObsWrapper](../figures/lavacrossing_RGBImgPartialObsWrapper.png)
    """

    def __init__(self, env, tile_size=8):
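        """A wrapper that renders the agent's partial view as an RGB image observation.

        Args:
            env: The environment to apply the wrapper
            tile_size: The size (in pixels) used to render each tile of the agent view
        """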
        super().__init__(env)

        # Rendering attributes for observations
        self.tile_size = tile_size

        obs_shape = env.observation_space.spaces["image"].shape
        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(obs_shape[0] * tile_size, obs_shape[1] * tile_size, 3),
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        rgb_img_partial = self.get_frame(tile_size=self.tile_size, agent_pov=True)

        return {**obs, "image": rgb_img_partial}


class FullyObsWrapper(ObservationWrapper):
    """
    Fully observable gridworld using a compact grid encoding instead of the agent view.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FullyObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = FullyObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
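        """A wrapper that makes the full grid visible through a compact encoding.

        Args:
            env: The environment to apply the wrapper
        """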
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )

        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped
        full_grid = env.grid.encode()
        # Overlay the agent (position and direction) on the encoded grid
        full_grid[env.agent_pos[0]][env.agent_pos[1]] = np.array(
            [OBJECT_TO_IDX["agent"], COLOR_TO_IDX["red"], env.agent_dir]
        )

        return {**obs, "image": full_grid}


class DictObservationSpaceWrapper(ObservationWrapper):
    """
    Transforms the observation space (that has a textual component) to a fully numerical observation space,
    where the textual instructions are replaced by arrays representing the indices of each word in a fixed vocabulary.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DictObservationSpaceWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['mission']
        'avoid the lava and get to the green goal square'
        >>> env_obs = DictObservationSpaceWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['mission'][:10]
        [19, 31, 17, 36, 20, 38, 31, 2, 15, 35]
    """

    def __init__(self, env, max_words_in_mission=50, word_dict=None):
        """A wrapper that encodes the mission string as an array of word indices.

        Args:
            env: The environment to apply the wrapper
            max_words_in_mission: The length of the array used to represent a mission, padded with 0 for missing words
            word_dict: A dictionary of words to use (keys=words, values=indices starting from 1);
                if None, use the MiniGrid language
        """
        super().__init__(env)

        if word_dict is None:
            word_dict = self.get_minigrid_words()

        self.max_words_in_mission = max_words_in_mission
        self.word_dict = word_dict

        self.observation_space = spaces.Dict(
            {
                "image": env.observation_space["image"],
                "direction": spaces.Discrete(4),
                # Word indices are offset by 1, with 0 reserved for padding,
                # so the largest possible value is len(word_dict)
                "mission": spaces.MultiDiscrete(
                    [len(self.word_dict.keys()) + 1] * max_words_in_mission
                ),
            }
        )

    @staticmethod
    def get_minigrid_words():
        colors = ["red", "green", "blue", "yellow", "purple", "grey"]
        objects = [
            "unseen",
            "empty",
            "wall",
            "floor",
            "box",
            "key",
            "ball",
            "door",
            "goal",
            "agent",
            "lava",
        ]

        verbs = [
            "pick",
            "avoid",
            "get",
            "find",
            "put",
            "use",
            "open",
            "go",
            "fetch",
            "reach",
            "unlock",
            "traverse",
        ]

        extra_words = [
            "up",
            "the",
            "a",
            "at",
            ",",
            "square",
            "and",
            "then",
            "to",
            "of",
            "rooms",
            "near",
            "opening",
            "must",
            "you",
            "matching",
            "end",
            "hallway",
            "object",
            "from",
            "room",
        ]

        all_words = colors + objects + verbs + extra_words
        assert len(all_words) == len(set(all_words))
        return {word: i for i, word in enumerate(all_words)}

    def string_to_indices(self, string, offset=1):
        """
        Convert a string to a list of word indices, each shifted by `offset`
        (0 is reserved for padding).
        """
        indices = []
        # adding space before and after commas
        string = string.replace(",", " , ")
        for word in string.split():
            if word in self.word_dict.keys():
                indices.append(self.word_dict[word] + offset)
            else:
                raise ValueError(f"Unknown word: {word}")
        return indices

    def observation(self, obs):
        obs["mission"] = self.string_to_indices(obs["mission"])
        assert len(obs["mission"]) < self.max_words_in_mission
        # Pad the mission with zeros up to the fixed length
        obs["mission"] += [0] * (self.max_words_in_mission - len(obs["mission"]))

        return obs


class FlatObsWrapper(ObservationWrapper):
    """
    Encode mission strings using a one-hot scheme,
    and combine these with observed images into one flat array.

    This wrapper is not applicable to BabyAI environments, given that these have their own language component.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import FlatObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = FlatObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs.shape
        (2835,)
    """

    def __init__(self, env, maxStrLen=96):
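        """A wrapper that flattens the image and a one-hot mission encoding into a single array.

        Args:
            env: The environment to apply the wrapper
            maxStrLen: The maximum length of a mission string; shorter missions are zero-padded
        """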
        super().__init__(env)

        self.maxStrLen = maxStrLen
        self.numCharCodes = 28  # 26 letters + space + comma

        imgSpace = env.observation_space.spaces["image"]
        imgSize = reduce(operator.mul, imgSpace.shape, 1)

        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(imgSize + self.numCharCodes * self.maxStrLen,),
            dtype="uint8",
        )

        self.cachedStr: str | None = None

    def observation(self, obs):
        image = obs["image"]
        mission = obs["mission"]

        # Cache the last-encoded mission string
        if mission != self.cachedStr:
            assert (
                len(mission) <= self.maxStrLen
            ), f"mission string too long ({len(mission)} chars)"
            mission = mission.lower()

            strArray = np.zeros(
                shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
            )

            # One-hot encode each character: 0-25 for a-z, 26 for space, 27 for comma
            for idx, ch in enumerate(mission):
                if ch >= "a" and ch <= "z":
                    chNo = ord(ch) - ord("a")
                elif ch == " ":
                    chNo = ord("z") - ord("a") + 1
                elif ch == ",":
                    chNo = ord("z") - ord("a") + 2
                else:
                    raise ValueError(
                        f"Character {ch} is not available in mission string."
                    )
                assert chNo < self.numCharCodes, f"{ch} : {chNo}"
                strArray[idx, chNo] = 1

            self.cachedStr = mission
            self.cachedArray = strArray

        obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

        return obs


class ViewSizeWrapper(ObservationWrapper):
    """
    Wrapper to customize the agent's field of view size.
    This cannot be used with fully observable wrappers.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import ViewSizeWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = ViewSizeWrapper(env, agent_view_size=5)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (5, 5, 3)
    """

    def __init__(self, env, agent_view_size=7):
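        """A wrapper that customizes the agent's field of view size.

        Args:
            env: The environment to apply the wrapper
            agent_view_size: The side length of the square agent view; must be odd and at least 3
        """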
        super().__init__(env)

        assert agent_view_size % 2 == 1
        assert agent_view_size >= 3

        self.agent_view_size = agent_view_size

        # Compute observation space with specified view size
        new_image_space = gym.spaces.Box(
            low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8"
        )

        # Override the environment's observation space
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        env = self.unwrapped

        grid, vis_mask = env.gen_obs_grid(self.agent_view_size)

        # Encode the partially observable view into a numpy array
        image = grid.encode(vis_mask)

        return {**obs, "image": image}


class DirectionObsWrapper(ObservationWrapper):
    """
    Provides the slope/angular direction to the goal with the observations,
    modeled as (y2 - y1) / (x2 - x1), where type = {slope, angle}.

    Example:
        >>> import gymnasium as gym
        >>> import matplotlib.pyplot as plt
        >>> from minigrid.wrappers import DirectionObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> env_obs = DirectionObsWrapper(env, type="slope")
        >>> obs, _ = env_obs.reset()
        >>> obs['goal_direction']
        1.0
    """

    def __init__(self, env, type="slope"):
        super().__init__(env)
        self.goal_position: tuple = None
        self.type = type

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        obs, info = self.env.reset(seed=seed, options=options)

        if not self.goal_position:
            self.goal_position = [
                x for x, y in enumerate(self.grid.grid) if isinstance(y, Goal)
            ]
            # In case there are multiple goals, keep the first one;
            # other env types may need different handling
            if len(self.goal_position) >= 1:
                # The grid is stored as a flat, row-major list (index = y * width + x),
                # so decode the flat index into (x, y) to match `agent_pos`
                self.goal_position = (
                    self.goal_position[0] % self.width,
                    self.goal_position[0] // self.width,
                )

        return self.observation(obs), info

    def observation(self, obs):
        # Slope (y2 - y1) / (x2 - x1) between the agent and the goal
        slope = np.divide(
            self.goal_position[1] - self.agent_pos[1],
            self.goal_position[0] - self.agent_pos[0],
        )

        if self.type == "angle":
            obs["goal_direction"] = np.arctan(slope)
        else:
            obs["goal_direction"] = slope

        return obs


class SymbolicObsWrapper(ObservationWrapper):
    """
    Fully observable grid with a symbolic state representation.
    The symbol is a triple of (X, Y, IDX), where X and Y are
    the coordinates on the grid, and IDX is the id of the object.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import SymbolicObsWrapper
        >>> env = gym.make("MiniGrid-LavaCrossingS11N5-v0")
        >>> obs, _ = env.reset()
        >>> obs['image'].shape
        (7, 7, 3)
        >>> env_obs = SymbolicObsWrapper(env)
        >>> obs, _ = env_obs.reset()
        >>> obs['image'].shape
        (11, 11, 3)
    """

    def __init__(self, env):
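        """A wrapper that replaces the agent view with a symbolic (X, Y, IDX) grid encoding.

        Args:
            env: The environment to apply the wrapper
        """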
        super().__init__(env)

        new_image_space = spaces.Box(
            low=0,
            high=max(OBJECT_TO_IDX.values()),
            shape=(self.env.width, self.env.height, 3),  # number of cells
            dtype="uint8",
        )
        self.observation_space = spaces.Dict(
            {**self.observation_space.spaces, "image": new_image_space}
        )

    def observation(self, obs):
        # Cells without an object are encoded as -1
        objects = np.array(
            [OBJECT_TO_IDX[o.type] if o is not None else -1 for o in self.grid.grid]
        )
        agent_pos = self.env.agent_pos
        ncol, nrow = self.width, self.height
        # Coordinate channels (X, Y) for every cell
        grid = np.mgrid[:ncol, :nrow]
        # Reshape the flat, row-major object list to match the (ncol, nrow) layout
        _objects = np.transpose(objects.reshape(1, nrow, ncol), (0, 2, 1))

        grid = np.concatenate([grid, _objects])
        grid = np.transpose(grid, (1, 2, 0))
        grid[agent_pos[0], agent_pos[1], 2] = OBJECT_TO_IDX["agent"]
        obs["image"] = grid

        return obs


class StochasticActionWrapper(ActionWrapper):
    """
    Add stochasticity to the actions.

    With probability `prob`, the requested action is executed. Otherwise,
    `random_action` is executed if it was provided, else a random action
    is sampled from the action space.
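
    Example (illustrative only, since step outcomes are random):
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import StochasticActionWrapper
        >>> env = gym.make("MiniGrid-Empty-5x5-v0")
        >>> env = StochasticActionWrapper(env, prob=0.9)
        >>> obs, _ = env.reset(seed=0)
        >>> obs, reward, terminated, truncated, info = env.step(1)  # doctest: +SKIP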
    """

    def __init__(self, env=None, prob=0.9, random_action=None):
        """A wrapper that executes the requested action only with probability `prob`.

        Args:
            env: The environment to apply the wrapper
            prob: The probability of executing the requested action
            random_action: The action to execute otherwise; if None, one is sampled
        """
        super().__init__(env)
        self.prob = prob
        self.random_action = random_action

    def action(self, action):
        """Returns the given action with probability `prob`, else a random one."""
        if np.random.uniform() < self.prob:
            return action
        else:
            if self.random_action is None:
                # `integers` excludes the upper bound, so this samples actions 0-5
                return self.np_random.integers(0, high=6)
            else:
                return self.random_action


class NoDeath(Wrapper):
    """
    Wrapper to prevent death in specific cells (e.g., lava cells).
    Instead of dying, the agent will receive a negative reward.

    Example:
        >>> import gymnasium as gym
        >>> from minigrid.wrappers import NoDeath
        >>>
        >>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (0, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, _, _, _, _ = env.step(1)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1.0, False)
        >>>
        >>>
        >>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-1, True)
        >>>
        >>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
        >>> _, _ = env.reset(seed=2)
        >>> _, reward, term, *_ = env.step(2)
        >>> reward, term
        (-2.0, False)
    """

    def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
        """A wrapper to prevent death in specific cells.

        Args:
            env: The environment to apply the wrapper
            no_death_types: Tuple of cell type names that should not be lethal
            death_cost: The negative reward received in death cells
        """
        assert "goal" not in no_death_types, "goal cannot be a death cell"

        super().__init__(env)
        self.death_cost = death_cost
        self.no_death_types = no_death_types

    def step(self, action):
        # In Dynamic-Obstacles, obstacles move after the agent moves,
        # so we need to check for collision before self.env.step()
        front_cell = self.grid.get(*self.front_pos)
        going_to_death = (
            action == self.actions.forward
            and front_cell is not None
            and front_cell.type in self.no_death_types
        )

        obs, reward, terminated, truncated, info = self.env.step(action)

        # We also check if the agent stays in death cells (e.g., lava)
        # without moving
        current_cell = self.grid.get(*self.agent_pos)
        in_death = current_cell is not None and current_cell.type in self.no_death_types

        if terminated and (going_to_death or in_death):
            terminated = False
            reward += self.death_cost

        return obs, reward, terminated, truncated, info