You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

43 lines
1.8 KiB

from __future__ import annotations
from minigrid.core.grid import Grid
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
from minigrid.envs.adversaries_base import AdversaryEnv
from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject
import numpy as np
class OscillatingAdversaries(AdversaryEnv):
    """Walled grid with one goal and two adversaries cycling through waypoints.

    ``_gen_grid`` places a goal at ``(width // 2 - 1, height - 2)``, the agent
    at ``(width // 2, 1)``, and registers a yellow and a green adversary, each
    with a repeating list of three ``GoTo`` tasks.
    """

    def __init__(self, width=8, height=8, max_steps: int | None = None, **kwargs):
        """Create the environment.

        Args:
            width: Grid width, forwarded to ``AdversaryEnv``.
            height: Grid height, forwarded to ``AdversaryEnv``.
            max_steps: Step limit forwarded to ``AdversaryEnv``; ``None``
                selects 200.
            **kwargs: Passed through unchanged to ``AdversaryEnv``.
        """
        if max_steps is None:
            max_steps = 200
        super().__init__(
            width=width, height=height, max_steps=max_steps, **kwargs
        )

    def _gen_grid(self, width, height):
        """Build the grid, then place the goal, the agent and both adversaries.

        Args:
            width: Grid width; must be at least 8.
            height: Grid height; must be at least 8.

        Raises:
            AssertionError: If ``width`` or ``height`` is smaller than 8.
        """
        # The waypoints below are hard-coded up to x = 3 / y = 5, hence the
        # minimum size. Kept as an ``assert`` so the exception type callers
        # see is unchanged.
        # NOTE(review): this check disappears under ``python -O``.
        assert width >= 8 and height >= 8, (
            f"grid must be at least 8x8, got {width}x{height}"
        )

        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        mid_x = width // 2

        # Goal: one column left of the middle, on row ``height - 2``.
        goal_pos = np.array((mid_x - 1, height - 2))
        self.put_obj(Goal(), *goal_pos)

        # Start from an empty adversary registry before adding the two below
        # (presumably ``add_adversary`` fills this dict -- confirm in
        # ``AdversaryEnv``).
        self.adversaries = {}

        # Agent: middle column, row 1.
        self.agent_pos = np.array((mid_x, 1))
        self.agent_dir = 2

        # Yellow adversary: starts at (1, 3) and cycles through three
        # waypoints, the last of which is its own starting cell.
        self.add_adversary(
            1,
            3,
            "yellow",
            direction=0,
            tasks=[GoTo((3, 3)), GoTo((3, 1)), GoTo((1, 3))],
            repeating=True,
        )

        # Green adversary: starts at (width - 2, 5) and cycles through three
        # waypoints expressed relative to the grid width.
        self.add_adversary(
            width - 2,
            5,
            "green",
            direction=3,
            tasks=[
                GoTo((width - 2, 3)),
                GoTo((width - 4, 3)),
                GoTo((width - 4, 5)),
            ],
            repeating=True,
        )

    def step(self, action):
        """Advance the environment by delegating to the parent ``step``.

        Args:
            action: Action forwarded unchanged to ``AdversaryEnv.step``.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple produced
            by the parent implementation, unmodified.
        """
        obs, reward, terminated, truncated, info = super().step(action)
        return obs, reward, terminated, truncated, info