You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

91 lines
3.1 KiB

import random
import minigrid
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
from Wrapper import OneHotWrapper
class ParametricActionsMiniGridEnv(gym.Env):
    """Parametric-action version of a MiniGrid environment.

    On every ``reset`` and ``step`` two distinct indices of the action space
    are drawn at random and marked as the currently valid "left" and "right"
    actions. Choosing the "left" index forwards action 0 to the wrapped
    environment, choosing the "right" index forwards action 1, and any other
    index silently falls back to action 0.

    The action mask and the per-action embeddings are kept on the instance
    (``action_mask`` / ``action_assignments``) but are not part of the
    returned observation: ``reset`` and ``step`` hand back the wrapped
    environment's observation unchanged, and ``observation_space`` is the
    wrapped environment's space.
    """

    def __init__(self, config):
        """Build the wrapped environment and mirror its spaces.

        Args:
            config: Mapping of options; ``None`` or an empty mapping selects
                the defaults. The only key read is ``"name"``, the Gymnasium
                id of the environment to wrap (default
                ``"MiniGrid-LavaCrossingS9N1-v0"``).
        """
        config = config or {}
        name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")

        # Random 2-d embeddings for the two valid actions, drawn once per
        # instance from NumPy's global generator.
        self.left_action_embed = np.random.randn(2)
        self.right_action_embed = np.random.randn(2)

        self.wrapped = gym.make(name)
        print(F"Wrapped environment is {self.wrapped}")

        self.step_count = 0
        self.action_space = self.wrapped.action_space
        self.observation_space = self.wrapped.observation_space

    def update_avail_actions(self):
        """Re-draw which two action indices are currently valid.

        Sets ``left_idx`` / ``right_idx`` to two distinct random indices,
        ``action_mask`` to an int8 vector with ones at exactly those indices,
        and ``action_assignments`` to a float32 ``(n, 2)`` array that holds
        the left/right embeddings in those rows and zeros elsewhere.

        The indices come from the module-level ``random`` generator, so they
        are not affected by the ``seed`` passed to ``reset``.
        """
        n_actions = self.action_space.n
        self.action_assignments = np.zeros((n_actions, 2), dtype=np.float32)
        self.action_mask = np.zeros(n_actions, dtype=np.int8)

        self.left_idx, self.right_idx = random.sample(range(n_actions), 2)
        self.action_assignments[self.left_idx] = self.left_action_embed
        self.action_assignments[self.right_idx] = self.right_action_embed
        self.action_mask[self.left_idx] = 1
        self.action_mask[self.right_idx] = 1

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and re-draw the valid actions.

        Args:
            seed: Optional seed, forwarded to the wrapped environment.
            options: Optional reset options, forwarded to the wrapped
                environment.

        Returns:
            The ``(observation, info)`` pair produced by the wrapped
            environment, unchanged.
        """
        # Seed this env's own RNG as the Gymnasium API recommends.
        super().reset(seed=seed)
        self.update_avail_actions()
        # Forward seed/options so callers can actually control the episode.
        obs, infos = self.wrapped.reset(seed=seed, options=options)
        return obs, infos

    def step(self, action):
        """Translate ``action`` and step the wrapped environment.

        Args:
            action: Index into the action space. Only ``left_idx`` and
                ``right_idx`` (as drawn by the last ``update_avail_actions``
                call) are meaningful.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` from the
            wrapped environment; ``info`` additionally carries a ``"Hello"``
            debug entry.
        """
        # left_idx and right_idx are always distinct, so this is equivalent
        # to: left -> 0, right -> 1, anything else -> 0 (invalid choices are
        # deliberately tolerated rather than raising).
        # NOTE(review): 0/1 are presumably MiniGrid's turn-left/turn-right
        # actions - confirm against the wrapped env.
        actual_action = 1 if action == self.right_idx else 0

        orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)

        # New valid actions for the next step.
        self.update_avail_actions()

        print(F"Info is {info}")
        # Store the debug marker. The previous `info["Hello" : "..."]` was a
        # slice lookup that raised on every call instead of assigning.
        info["Hello"] = "Ich kenn mich nix aus"
        return orig_obs, rew, done, truncated, info