Thomas Knoll
1 year ago
3 changed files with 36 additions and 120 deletions
-
56examples/shields/rl/11_minigridrl.py
-
91examples/shields/rl/MaskEnvironments.py
-
5examples/shields/rl/Wrapper.py
@ -1,91 +0,0 @@ |
|||||
import random |
|
||||
import minigrid |
|
||||
|
|
||||
import gymnasium as gym |
|
||||
import numpy as np |
|
||||
from gymnasium.spaces import Box, Dict, Discrete |
|
||||
from Wrapper import OneHotWrapper |
|
||||
|
|
||||
|
|
||||
class ParametricActionsMiniGridEnv(gym.Env): |
|
||||
"""Parametric action version of MiniGrid. |
|
||||
|
|
||||
""" |
|
||||
|
|
||||
def __init__(self, config): |
|
||||
|
|
||||
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
|
||||
self.left_action_embed = np.random.randn(2) |
|
||||
self.right_action_embed = np.random.randn(2) |
|
||||
framestack = config.get("framestack", 4) |
|
||||
|
|
||||
# env = gym.make(name) |
|
||||
# env = minigrid.wrappers.ImgObsWrapper(env) |
|
||||
# env = OneHotWrapper(env, |
|
||||
# config.vector_index if hasattr(config, "vector_index") else 0, |
|
||||
# framestack=framestack |
|
||||
# ) |
|
||||
self.wrapped = gym.make(name) |
|
||||
# self.observation_space = Dict( |
|
||||
# { |
|
||||
# "action_mask": None, |
|
||||
# "avail_actions": None, |
|
||||
# "cart": self.wrapped.observation_space, |
|
||||
# } |
|
||||
# ) |
|
||||
print(F"Wrapped environment is {self.wrapped}") |
|
||||
self.step_count = 0 |
|
||||
self.action_space = self.wrapped.action_space |
|
||||
self.observation_space = self.wrapped.observation_space |
|
||||
|
|
||||
|
|
||||
def update_avail_actions(self): |
|
||||
self.action_assignments = np.array( |
|
||||
[[0.0, 0.0]] * self.action_space.n, dtype=np.float32 |
|
||||
) |
|
||||
self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8) |
|
||||
self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2) |
|
||||
self.action_assignments[self.left_idx] = self.left_action_embed |
|
||||
self.action_assignments[self.right_idx] = self.right_action_embed |
|
||||
self.action_mask[self.left_idx] = 1 |
|
||||
self.action_mask[self.right_idx] = 1 |
|
||||
|
|
||||
def reset(self, *, seed=None, options=None): |
|
||||
self.update_avail_actions() |
|
||||
obs, infos = self.wrapped.reset() |
|
||||
return obs, infos |
|
||||
return { |
|
||||
"action_mask": self.action_mask, |
|
||||
"avail_action": self.action_assignments, |
|
||||
"cart": obs, |
|
||||
}, infos |
|
||||
|
|
||||
def step(self, action): |
|
||||
if action == self.left_idx: |
|
||||
actual_action = 0 |
|
||||
elif action == self.right_idx: |
|
||||
actual_action = 1 |
|
||||
else: |
|
||||
actual_action = 0 |
|
||||
# raise ValueError( |
|
||||
# "Chosen action was not one of the non-zero action embeddings", |
|
||||
# action, |
|
||||
# self.action_assignments, |
|
||||
# self.action_mask, |
|
||||
# self.left_idx, |
|
||||
# self.right_idx, |
|
||||
# ) |
|
||||
orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action) |
|
||||
self.update_avail_actions() |
|
||||
self.action_mask = self.action_mask.astype(np.int8) |
|
||||
print(F"Info is {info}") |
|
||||
info["Hello" : "Ich kenn mich nix aus"] |
|
||||
return orig_obs, rew, done, truncated, info |
|
||||
obs = { |
|
||||
"action_mask": self.action_mask, |
|
||||
"action_mask": self.action_assignments, |
|
||||
"cart": orig_obs, |
|
||||
} |
|
||||
return obs, rew, done, truncated, info |
|
||||
|
|
||||
|
|
Write
Preview
Loading…
Cancel
Save
Reference in new issue