You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

202 lines
8.5 KiB

from __future__ import annotations
from typing import Any, Callable
from gymnasium import spaces
from gymnasium.utils import seeding
def check_if_no_duplicate(duplicate_list: list) -> bool:
    """Report whether ``duplicate_list`` is free of repeated values.

    Uniqueness is decided by collapsing the items into a ``set``, so every
    element must be hashable.

    Args:
        duplicate_list: The items to inspect.

    Returns:
        ``True`` if no value occurs more than once, ``False`` otherwise.
    """
    distinct_count = len(set(duplicate_list))
    total_count = len(duplicate_list)
    return distinct_count == total_count
class MissionSpace(spaces.Space[str]):
    r"""A space representing a mission for the Gym-Minigrid environments.

    The space allows generating random mission strings constructed with an input placeholder list.

    Example Usage::

        >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
        ...                                  ordered_placeholders=[["green", "blue"]])
        >>> _ = observation_space.seed(123)
        >>> observation_space.sample()
        'Get the green ball.'
        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.",
        ...                                  ordered_placeholders=None)
        >>> observation_space.sample()
        'Get the ball.'
    """

    def __init__(
        self,
        mission_func: Callable[..., str],
        ordered_placeholders: list[list[str]] | None = None,
        seed: int | seeding.RandomNumberGenerator | None = None,
    ):
        r"""Constructor of :class:`MissionSpace` space.

        Args:
            mission_func (lambda _placeholders(str): _mission(str)): Function that generates a
                mission string from random placeholders. It is called positionally with one
                value per entry of ``ordered_placeholders`` (or with no arguments when that is
                ``None``).
            ordered_placeholders (Optional["list[list[str]]"]): List of lists of placeholders,
                ordered in the same order as the parameters of ``mission_func``. The values
                within one list must be unique.
            seed: The seed for sampling from the space.

        Raises:
            AssertionError: If the number of placeholder lists differs from
                ``mission_func.__code__.co_argcount``, if a placeholder list contains
                duplicates, or if ``mission_func`` does not return a ``str``.
        """
        # Check that the ordered placeholders and mission function are well defined.
        # NOTE(review): ``co_argcount`` counts ``self`` for bound methods and excludes
        # ``*args``, so such callables may fail this check even if they would work.
        if ordered_placeholders is not None:
            assert (
                len(ordered_placeholders) == mission_func.__code__.co_argcount
            ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
            for placeholder_list in ordered_placeholders:
                assert check_if_no_duplicate(
                    placeholder_list
                ), "Make sure that the placeholders don't have any duplicate values."
        else:
            assert (
                mission_func.__code__.co_argcount == 0
            ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."
        self.ordered_placeholders = ordered_placeholders
        self.mission_func = mission_func
        super().__init__(dtype=str, seed=seed)

        # Check that mission_func returns a string. This draws one sample, which
        # advances ``self.np_random`` when placeholders are used.
        sampled_mission = self.sample()
        assert isinstance(
            sampled_mission, str
        ), f"mission_func must return type str not {type(sampled_mission)}"

    def sample(self) -> str:
        """Sample a random mission string.

        For each list in ``ordered_placeholders`` (in order) one entry is picked with
        ``self.np_random.integers`` and the picks are passed positionally to
        ``mission_func``. Without placeholders, ``mission_func()`` is returned.
        """
        if self.ordered_placeholders is not None:
            placeholders = []
            for rand_var_list in self.ordered_placeholders:
                idx = self.np_random.integers(0, len(rand_var_list))
                placeholders.append(rand_var_list[idx])
            return self.mission_func(*placeholders)
        else:
            return self.mission_func()

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        Non-string inputs are rejected. Without placeholders, ``x`` must equal
        ``mission_func()``. With placeholders, every occurrence of every candidate
        placeholder is located in ``x``. Where two matches overlap, the shorter one is
        discarded (the earlier one on a length tie). The remaining matches, read left to
        right, must supply exactly one valid value for each slot of
        ``ordered_placeholders``, and ``mission_func`` applied to them must reproduce
        ``x``.

        Note: placeholders are located by plain substring search, so a placeholder value
        that also appears in the mission's fixed text (or inside an unrelated word) can
        produce an extra match and cause a valid mission to be rejected.
        """
        if not isinstance(x, str):
            return False
        if self.ordered_placeholders is None:
            return bool(self.mission_func() == x)

        occurrences = self._drop_overlapping(self._find_placeholder_occurrences(x))
        final_placeholders = [placeholder for _, _, placeholder in occurrences]

        # A valid mission fills every slot exactly once; bail out early instead of
        # calling mission_func with the wrong number of arguments.
        if len(final_placeholders) != len(self.ordered_placeholders):
            return False

        # Check that the identified final placeholders are in the same order as the original placeholders.
        for slot_values, final_placeholder in zip(
            self.ordered_placeholders, final_placeholders
        ):
            if final_placeholder not in slot_values:
                return False

        # mission_func is user-supplied, so any exception it raises just means "not contained".
        try:
            mission_string_with_placeholders = self.mission_func(*final_placeholders)
        except Exception as e:
            print(
                f"{x} is not contained in MissionSpace due to the following exception: {e}"
            )
            return False
        return bool(mission_string_with_placeholders == x)

    def _find_placeholder_occurrences(self, x: str) -> list[tuple[int, int, str]]:
        """Locate every (possibly overlapping) occurrence of every known placeholder in ``x``.

        Returns ``(start, end, placeholder)`` tuples sorted by position, where ``end`` is
        exclusive, i.e. ``x[start:end] == placeholder``. Requires
        ``self.ordered_placeholders`` to be non-``None``.
        """
        # The same string may be a candidate for several slots; search for it only once.
        candidates = {
            placeholder
            for placeholder_list in self.ordered_placeholders
            for placeholder in placeholder_list
            if placeholder in x
        }
        occurrences = [
            (start, start + len(placeholder), placeholder)
            for placeholder in candidates
            for start in range(len(x))
            if x.startswith(placeholder, start)
        ]
        return sorted(occurrences)

    @staticmethod
    def _drop_overlapping(
        occurrences: list[tuple[int, int, str]],
    ) -> list[tuple[int, int, str]]:
        """Resolve overlapping matches, keeping the longest placeholder.

        ``occurrences`` holds ``(start, end, placeholder)`` tuples with exclusive ``end``.
        For every overlapping pair the shorter match is discarded (the earlier one on a
        length tie). Indices are collected first and filtered afterwards, so removals
        never shift the positions of other entries. Input order is preserved.
        """
        to_remove = set()
        for i, (start_1, end_1, placeholder_1) in enumerate(occurrences):
            for j in range(i + 1, len(occurrences)):
                start_2, end_2, placeholder_2 = occurrences[j]
                # Half-open intervals [start, end) overlap iff this holds; this also
                # catches single-character overlaps.
                if max(start_1, start_2) < min(end_1, end_2):
                    to_remove.add(i if len(placeholder_1) <= len(placeholder_2) else j)
        return [occ for k, occ in enumerate(occurrences) if k not in to_remove]

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        Two mission spaces are equal when either:

        * neither uses placeholders and both mission functions return the same string; or
        * both have the same number of placeholder slots, each slot holds the same set of
          values (order within a slot is ignored), and both mission functions return the
          same string when every placeholder is replaced by ``""``.

        Anything that is not a :class:`MissionSpace` compares unequal.
        """
        if not isinstance(other, MissionSpace):
            return False

        self_placeholders = self.ordered_placeholders
        other_placeholders = other.ordered_placeholders

        if self_placeholders is None or other_placeholders is None:
            # Equal only if *neither* side uses placeholders; this also guards the
            # ``len(None)`` below when just one side lacks them.
            if self_placeholders is None and other_placeholders is None:
                return self.mission_func() == other.mission_func()
            return False

        # Check that place holder lists are the same (slot count and per-slot values).
        if len(self_placeholders) != len(other_placeholders):
            return False
        if not all(
            set(mine) == set(theirs)
            for mine, theirs in zip(self_placeholders, other_placeholders)
        ):
            return False

        # Check mission string is the same with dummy space placeholders
        test_placeholders = [""] * len(self_placeholders)
        return self.mission_func(*test_placeholders) == other.mission_func(
            *test_placeholders
        )