You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.3 KiB
73 lines
2.3 KiB
from __future__ import annotations
|
|
import math
|
|
|
|
from numpy.random import default_rng
|
|
|
|
from minigrid.core.constants import COLOR_NAMES
|
|
from minigrid.core.grid import Grid
|
|
from minigrid.core.mission import MissionSpace
|
|
from minigrid.envs.lavaslippery import LavaSlipperyEnv
|
|
from minigrid.core.world_object import (
|
|
Slippery,
|
|
Lava,
|
|
Goal,
|
|
Wall
|
|
)
|
|
|
|
from minigrid.minigrid_env import MiniGridEnv, is_slippery
|
|
|
|
import numpy as np
|
|
|
|
from loguru import logger
|
|
|
|
class MCHW11Env(LavaSlipperyEnv):
    """Fixed-layout slippery-lava grid built on top of ``LavaSlipperyEnv``.

    After the parent generates its grid, every interior cell is replaced by a
    ``Slippery`` tile (intended-move probability 0.91). A ``Lava`` tile is
    placed at (4, 3) and a ``Wall`` at (9, 3). The agent is placed at (1, 1)
    with direction 1, and the goal at ``(width - 2, height - 2)``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all construction arguments unchanged to ``LavaSlipperyEnv``."""
        super().__init__(*args, **kwargs)

    def _gen_grid(self, width: int, height: int) -> None:
        """Generate the fixed layout on top of the parent's grid.

        Args:
            width: Grid width in cells; forwarded to the parent generator and
                used to position the goal.
            height: Grid height in cells; forwarded to the parent generator
                and used to position the goal.
        """
        super()._gen_grid(width, height)

        # Fixed slip parameter for this layout. It overwrites whatever value
        # the parent/constructor set before this point.
        self.probability_intended = 0.91

        # A single Slippery instance is shared by every interior cell; the
        # outermost ring of cells is left as the parent generated it.
        slippery = Slippery(probability_intended=self.probability_intended)
        for x in range(1, self.width - 1):
            for y in range(1, self.height - 1):
                self.grid.set(x, y, slippery)

        # Presumably switches place_agent to its fixed-position branch
        # (randomize_start falsy) -- TODO confirm against the parent class.
        self.disable_random_start()

        # Hard-coded hazard positions; they presumably require a grid of at
        # least 10 x 4 cells so (9, 3) is in bounds -- TODO confirm.
        self.put_obj(Lava(), 4, 3)
        self.put_obj(Wall(), 9, 3)

        agent_dir = 1  # We do not consider envs where the robot can turn
        self.probability_turn_intended = 0.0
        self.place_agent(
            agent_pos=np.array((1, 1)),
            agent_dir=agent_dir,
            spawn_on_slippery=True,
        )
        self.place_goal(np.array((width - 2, height - 2)))
        if self.dense_rewards:
            self.run_bfs()

    def place_agent(self, spawn_on_slippery=False, agent_pos=None, agent_dir=0):
        """Set the agent's start position and direction.

        When ``self.randomize_start`` is truthy, a start cell is found by
        rejection sampling. ``x`` is drawn from ``[0, width)`` and ``y`` from
        ``[0, 3)`` (only the top three rows). A cell is accepted when it is
        empty, or when it is overlappable, is neither ``Lava`` nor ``Goal``,
        and (unless ``spawn_on_slippery``) is not slippery. The direction is
        then drawn uniformly from ``0..3``. All draws use the environment's
        seeded ``self.np_random`` generator, so starts are reproducible
        under ``reset(seed=...)``.

        Otherwise:
          * ``agent_dir is None`` sets the direction to 0 and leaves the
            current position unchanged;
          * any other ``agent_dir`` stores ``agent_pos`` and ``agent_dir``
            verbatim.

        Args:
            spawn_on_slippery: Allow random starts on slippery cells.
            agent_pos: Start position used when random starts are disabled.
            agent_dir: Start direction used when random starts are disabled.

        Raises:
            RecursionError: If no acceptable cell is found within 10,000
                samples. This exception type is kept for callers that catch it.
        """
        max_tries = 10_000

        if self.randomize_start:
            rng = self.np_random  # env-seeded generator, not global np.random
            for _ in range(max_tries):
                x = rng.integers(0, self.width)
                y = rng.integers(0, 3)

                cell = self.grid.get(x, y)
                if cell is None or (
                    cell.can_overlap()
                    and not isinstance(cell, (Lava, Goal))
                    and (spawn_on_slippery or not is_slippery(cell))
                ):
                    self.agent_pos = np.array((x, y))
                    self.agent_dir = rng.integers(0, 4)
                    return
            raise RecursionError("rejection sampling failed in place_agent")
        elif agent_dir is None:
            # Only the direction is reset; the position is not modified here.
            self.agent_dir = 0
        else:
            # NOTE(review): with the default agent_pos=None this stores None
            # as the position; callers are expected to pass a position.
            self.agent_pos = agent_pos
            self.agent_dir = agent_dir
|
|
|