You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
58 lines
1.5 KiB
58 lines
1.5 KiB
from __future__ import annotations
|
|
|
|
import gymnasium as gym
|
|
import pytest
|
|
|
|
from minigrid.utils.baby_ai_bot import BabyAIBot
|
|
|
|
# see discussion starting here: https://github.com/Farama-Foundation/Minigrid/pull/381#issuecomment-1646800992
|
|
broken_bonus_envs = {
|
|
"BabyAI-PutNextS5N2Carrying-v0",
|
|
"BabyAI-PutNextS6N3Carrying-v0",
|
|
"BabyAI-PutNextS7N4Carrying-v0",
|
|
"BabyAI-KeyInBox-v0",
|
|
}
|
|
|
|
# get all babyai envs (except the broken ones)
|
|
babyai_envs = []
|
|
for k_i in gym.envs.registry.keys():
|
|
if k_i.split("-")[0] == "BabyAI":
|
|
if k_i not in broken_bonus_envs:
|
|
babyai_envs.append(k_i)
|
|
|
|
|
|
@pytest.mark.parametrize("env_id", babyai_envs)
|
|
def test_bot(env_id):
|
|
"""
|
|
The BabyAI Bot should be able to solve all BabyAI environments,
|
|
allowing us therefore to generate demonstrations.
|
|
"""
|
|
# Use the parameter env_id to make the environment
|
|
env = gym.make(env_id)
|
|
# env = gym.make(env_id, render_mode="human") # for visual debugging
|
|
|
|
# reset env
|
|
curr_seed = 0
|
|
|
|
num_steps = 240
|
|
terminated = False
|
|
while not terminated:
|
|
env.reset(seed=curr_seed)
|
|
|
|
# create expert bot
|
|
expert = BabyAIBot(env)
|
|
|
|
last_action = None
|
|
for _step in range(num_steps):
|
|
action = expert.replan(last_action)
|
|
obs, reward, terminated, truncated, info = env.step(action)
|
|
last_action = action
|
|
env.render()
|
|
|
|
if terminated:
|
|
break
|
|
|
|
# try again with a different seed
|
|
curr_seed += 1
|
|
|
|
env.close()
|