You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
202 lines
8.5 KiB
202 lines
8.5 KiB
from __future__ import annotations
|
|
|
|
from typing import Any, Callable
|
|
|
|
from gymnasium import spaces
|
|
from gymnasium.utils import seeding
|
|
|
|
|
|
def check_if_no_duplicate(duplicate_list: list) -> bool:
    """Return ``True`` when every element of ``duplicate_list`` is unique.

    Uniqueness is decided by collapsing the list into a ``set`` and comparing
    sizes, so the elements must be hashable. Despite the name, the result is
    ``True`` for a duplicate-free list and ``False`` otherwise.
    """
    distinct_count = len(set(duplicate_list))
    total_count = len(duplicate_list)
    return distinct_count == total_count
|
|
|
|
|
|
class MissionSpace(spaces.Space[str]):
    r"""A space representing a mission for the Gym-Minigrid environments.

    The space allows generating random mission strings constructed with an input placeholder list.

    Example Usage::

        >>> observation_space = MissionSpace(mission_func=lambda color: f"Get the {color} ball.",
        ...                                  ordered_placeholders=[["green", "blue"]])
        >>> _ = observation_space.seed(123)
        >>> observation_space.sample()
        'Get the green ball.'
        >>> observation_space = MissionSpace(mission_func=lambda : "Get the ball.",
        ...                                  ordered_placeholders=None)
        >>> observation_space.sample()
        'Get the ball.'
    """

    def __init__(
        self,
        mission_func: Callable[..., str],
        ordered_placeholders: list[list[str]] | None = None,
        seed: int | seeding.RandomNumberGenerator | None = None,
    ):
        r"""Constructor of :class:`MissionSpace` space.

        Args:
            mission_func (Callable[..., str]): Function that builds a mission string. It is
                called with one positional argument per entry of ``ordered_placeholders``,
                or with no arguments when ``ordered_placeholders`` is ``None``.
            ordered_placeholders (list[list[str]] | None): One list of candidate values per
                parameter of ``mission_func``, in the order the parameters are declared.
                Values inside each list must be unique.
            seed: The seed for sampling from the space.

        Raises:
            AssertionError: If the number of placeholder lists differs from
                ``mission_func.__code__.co_argcount``, if a placeholder list contains
                duplicates, or if ``mission_func`` does not return a ``str``.
        """
        # Check that the ordered placeholders and mission function are well defined.
        if ordered_placeholders is not None:
            assert (
                len(ordered_placeholders) == mission_func.__code__.co_argcount
            ), f"The number of placeholders {len(ordered_placeholders)} is different from the number of parameters in the mission function {mission_func.__code__.co_argcount}."
            for placeholder_list in ordered_placeholders:
                assert check_if_no_duplicate(
                    placeholder_list
                ), "Make sure that the placeholders don't have any duplicate values."
        else:
            assert (
                mission_func.__code__.co_argcount == 0
            ), f"If the ordered placeholders are {ordered_placeholders}, the mission function shouldn't have any parameters."

        self.ordered_placeholders = ordered_placeholders
        self.mission_func = mission_func

        super().__init__(dtype=str, seed=seed)

        # Check that mission_func returns a string. When placeholders are used this
        # draws from self.np_random, so it advances the RNG state derived from `seed`.
        sampled_mission = self.sample()
        assert isinstance(
            sampled_mission, str
        ), f"mission_func must return type str not {type(sampled_mission)}"

    def sample(self) -> str:
        """Sample a random mission string.

        Picks one random value from each placeholder list (via ``self.np_random``)
        and passes them positionally to ``mission_func``. Without placeholders,
        ``mission_func`` is called with no arguments.
        """
        if self.ordered_placeholders is None:
            return self.mission_func()

        placeholders = [
            rand_var_list[self.np_random.integers(0, len(rand_var_list))]
            for rand_var_list in self.ordered_placeholders
        ]
        return self.mission_func(*placeholders)

    def _find_placeholders(self, x: str) -> list[str]:
        """Return the placeholder values located in ``x``, ordered by position.

        Every occurrence of every placeholder value is found. When occurrences
        overlap (share at least one character), the longer one is kept; on equal
        length the later-starting occurrence is kept.
        """
        assert self.ordered_placeholders is not None

        distinct_values = {
            value
            for placeholder_list in self.ordered_placeholders
            for value in placeholder_list
        }

        # Half-open spans [start, end) of every occurrence, overlapping ones included.
        occurrences = [
            (start, start + len(value), value)
            for value in distinct_values
            if value in x
            for start in range(len(x))
            if x.startswith(value, start)
        ]

        # Longest values claim their span first; an occurrence is kept only if it
        # does not overlap anything already kept.
        kept: list[tuple[int, int, str]] = []
        for start, end, value in sorted(
            occurrences, key=lambda occ: (-len(occ[2]), -occ[0])
        ):
            if all(end <= k_start or k_end <= start for k_start, k_end, _ in kept):
                kept.append((start, end, value))

        kept.sort()
        return [value for _, _, value in kept]

    def contains(self, x: Any) -> bool:
        """Return boolean specifying if x is a valid member of this space.

        ``x`` must be a ``str``. Without placeholders it must equal
        ``mission_func()``. With placeholders, membership is decided heuristically:
        placeholder values occurring verbatim in ``x`` are located (overlaps resolved
        in favour of the longer value). Exactly one must be found per placeholder
        list, each must belong to the list at its position, and passing them to
        ``mission_func`` must reproduce ``x`` exactly.
        """
        if not isinstance(x, str):
            return False

        if self.ordered_placeholders is None:
            return bool(self.mission_func() == x)

        final_placeholders = self._find_placeholders(x)

        # One value per placeholder list is required. Otherwise mission_func could be
        # called with too few arguments (and fall back to its defaults) or too many.
        if len(final_placeholders) != len(self.ordered_placeholders):
            return False

        # Check that the identified final placeholders are in the same order as the original placeholders.
        for candidates, found in zip(self.ordered_placeholders, final_placeholders):
            if found not in candidates:
                return False

        try:
            mission_string_with_placeholders = self.mission_func(*final_placeholders)
        except Exception as e:
            # mission_func is user-supplied: any failure means x cannot be reproduced.
            print(
                f"{x} is not contained in MissionSpace due to the following exception: {e}"
            )
            return False

        return bool(mission_string_with_placeholders == x)

    def __repr__(self) -> str:
        """Gives a string representation of this space."""
        return f"MissionSpace({self.mission_func}, {self.ordered_placeholders})"

    def __eq__(self, other) -> bool:
        """Check whether ``other`` is equivalent to this instance.

        Both spaces must be ``MissionSpace`` instances whose placeholder lists hold the
        same values position by position (order within a list is ignored). Their
        mission functions must also produce the same string: with placeholders they
        are compared by calling both with an empty string for every placeholder.
        """
        if not isinstance(other, MissionSpace):
            return False

        if self.ordered_placeholders is None and other.ordered_placeholders is None:
            return self.mission_func() == other.mission_func()

        if self.ordered_placeholders is None or other.ordered_placeholders is None:
            # Exactly one of the two spaces uses placeholders.
            return False

        if len(self.ordered_placeholders) != len(other.ordered_placeholders):
            return False

        if not all(
            set(own) == set(theirs)
            for own, theirs in zip(self.ordered_placeholders, other.ordered_placeholders)
        ):
            return False

        # Check mission string is the same with dummy placeholders.
        test_placeholders = [""] * len(self.ordered_placeholders)
        mission = self.mission_func(*test_placeholders)
        other_mission = other.mission_func(*test_placeholders)
        return mission == other_mission
|