You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
4.8 KiB
130 lines
4.8 KiB
from __future__ import annotations
|
|
|
|
from minigrid.core.grid import Grid
|
|
from minigrid.core.mission import MissionSpace
|
|
from minigrid.core.world_object import Goal, Lava, SlipperyNorth, SlipperyEast, SlipperySouth, SlipperyWest, Ball
|
|
from minigrid.minigrid_env import MiniGridEnv, is_slippery
|
|
from minigrid.core.tasks import GoTo, DoNothing, PickUpObject, PlaceObject
|
|
|
|
import numpy as np
|
|
|
|
class AdversaryEnv(MiniGridEnv):
    """
    ## Description

    Grid environment in which the agent shares the world with moving
    adversaries. After every agent step each adversary in
    ``self.adversaries`` performs one action of its own; if an adversary
    moves forward into the agent's position the episode terminates and the
    reward is replaced by ``self.collision_penalty``.

    ``_gen_grid`` only builds an empty walled room here.

    NOTE(review): ``self.adversaries``, ``self.background_tiles``,
    ``self.trajectory`` and ``self.get_neighbours_prob`` are not defined in
    this class; they are presumably provided by the base ``MiniGridEnv`` or
    by subclasses -- confirm.
    """

    def __init__(
        self,
        width=7,
        height=6,
        generate_wall=True,
        generate_lava=False,
        generate_slippery=False,
        max_steps: int | None = None,
        **kwargs,
    ):
        """Create the environment.

        Args:
            width: Grid width in cells.
            height: Grid height in cells.
            generate_wall: Accepted for interface compatibility; not used
                by this class.
            generate_lava: Accepted for interface compatibility; not used
                by this class.
            generate_slippery: Accepted for interface compatibility; not
                used by this class.
            max_steps: Step limit per episode. Defaults to
                ``10 * (width * height) ** 2`` when ``None``.
            **kwargs: Forwarded unchanged to ``MiniGridEnv.__init__``.
        """
        if max_steps is None:
            max_steps = 10 * (width * height) ** 2
        mission_space = MissionSpace(mission_func=self._gen_mission)
        # Set before the base initialiser runs so the attribute exists for
        # the whole lifetime of the instance.
        self.collision_penalty = -1
        super().__init__(
            mission_space=mission_space,
            width=width,
            height=height,
            max_steps=max_steps,
            **kwargs,
        )

    @staticmethod
    def _gen_mission():
        """Return the fixed mission string of this environment."""
        return "Finish your task while avoiding the adversaries"

    def _gen_grid(self, width, height):
        """Build an empty grid of the given size surrounded by walls."""
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

    def step(self, action):
        """Advance the environment by one agent action, then move adversaries.

        Before the agent acts, every object parked in
        ``self.background_tiles`` whose grid cell is currently empty is put
        back onto the grid and removed from the background mapping.

        Args:
            action: The agent action, passed on to ``MiniGridEnv.step``.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple from the
            base class. If an adversary collides with the agent,
            ``terminated`` is set to ``True``, ``info["collision"]`` is set
            to ``True`` and ``reward`` is replaced by the collision penalty.
        """
        # Snapshot the positions first: the mapping is mutated below and
        # must not change size while it is being iterated.
        restorable = [
            position
            for position in self.background_tiles
            if self.grid.get(*position) is None
        ]
        for position in restorable:
            self.grid.set(*position, self.background_tiles.pop(position))
            self.grid.set_background(*position, None)

        obs, reward, terminated, truncated, info = super().step(action)

        if not terminated:
            agent_pos = self.agent_pos
            # Every adversary still moves and is logged, even after an
            # earlier one in the loop has already collided.
            for adversary in self.adversaries.values():
                collided = self.move_adversary(adversary, agent_pos)
                self.trajectory.append(
                    (adversary.color, adversary.adversary_pos, adversary.adversary_dir)
                )
                if collided:
                    terminated = True
                    info["collision"] = True
                    # Fall back to -1 if the attribute is missing. The
                    # previous `except e:` clause referenced an undefined
                    # name and would itself have raised NameError.
                    reward = getattr(self, "collision_penalty", -1)

        return obs, reward, terminated, truncated, info

    def move_adversary(self, adversary, agent_pos):
        """Let one adversary perform the action it selects.

        Args:
            adversary: Object exposing ``adversary_pos``, ``adversary_dir``,
                ``carrying``, ``dir_vec()`` and ``get_action(env)``.
            agent_pos: Current ``(x, y)`` position of the agent.

        Returns:
            ``True`` if the adversary chose ``forward`` and its target cell
            is the agent's position, otherwise ``False``. The flag is set
            whether or not the adversary was actually able to move there.

        Raises:
            ValueError: If the adversary returns an action that is not one
                of ``self.actions``.
        """
        # Current cell and the cell directly in front of the adversary.
        cur_pos = adversary.adversary_pos
        current_cell = self.grid.get(*adversary.adversary_pos)
        fwd_pos = cur_pos + adversary.dir_vec()
        fwd_cell = self.grid.get(*fwd_pos)
        collision = False
        need_position_update = False

        action = adversary.get_action(self)
        if action == self.actions.forward and is_slippery(current_cell):
            # On a slippery tile the forward target is sampled from the
            # tile's probabilities instead of being the cell straight ahead.
            probabilities = current_cell.get_probabilities(adversary.adversary_dir)
            possible_fwd_pos, prob = self.get_neighbours_prob(
                adversary.adversary_pos, probabilities
            )
            # NOTE(review): this uses NumPy's global RNG, so the draw is not
            # controlled by any environment-level seeding -- confirm intended.
            fwd_pos_index = np.random.choice(len(possible_fwd_pos), 1, p=prob)
            fwd_pos = possible_fwd_pos[fwd_pos_index[0]]
            fwd_cell = self.grid.get(*fwd_pos)
            need_position_update = True

        # Rotate left
        if action == self.actions.left:
            adversary.adversary_dir = (adversary.adversary_dir - 1) % 4

        # Rotate right
        elif action == self.actions.right:
            adversary.adversary_dir = (adversary.adversary_dir + 1) % 4

        # Move forward
        elif action == self.actions.forward:
            if fwd_pos[0] == agent_pos[0] and fwd_pos[1] == agent_pos[1]:
                collision = True
            if fwd_cell is None or fwd_cell.can_overlap():
                adversary.adversary_pos = tuple(fwd_pos)

        # Pick up an object
        elif action == self.actions.pickup:
            if fwd_cell and fwd_cell.can_pickup():
                if adversary.carrying is None:
                    adversary.carrying = fwd_cell
                    # A carried object is given the off-grid position (-1, -1).
                    adversary.carrying.cur_pos = np.array([-1, -1])
                    self.grid.set(fwd_pos[0], fwd_pos[1], None)

        # Drop an object
        elif action == self.actions.drop:
            if not fwd_cell and adversary.carrying:
                self.grid.set(fwd_pos[0], fwd_pos[1], adversary.carrying)
                adversary.carrying.cur_pos = fwd_pos
                adversary.carrying = None

        # Toggle/activate an object
        elif action == self.actions.toggle:
            if fwd_cell:
                fwd_cell.toggle(self, fwd_pos)

        # Done action (not used by default)
        elif action == self.actions.done:
            pass

        else:
            raise ValueError(f"Unknown action: {action}")

        # Slippery case: apply the sampled target. For a forward action this
        # repeats the assignment made in the forward branch above.
        if need_position_update and (fwd_cell is None or fwd_cell.can_overlap()):
            adversary.adversary_pos = tuple(fwd_pos)

        return collision
|