Thomas Knoll
1 year ago
3 changed files with 36 additions and 120 deletions
-
60examples/shields/rl/11_minigridrl.py
-
91examples/shields/rl/MaskEnvironments.py
-
5examples/shields/rl/Wrapper.py
@ -1,91 +0,0 @@ |
|||
import random |
|||
import minigrid |
|||
|
|||
import gymnasium as gym |
|||
import numpy as np |
|||
from gymnasium.spaces import Box, Dict, Discrete |
|||
from Wrapper import OneHotWrapper |
|||
|
|||
|
|||
class ParametricActionsMiniGridEnv(gym.Env): |
|||
"""Parametric action version of MiniGrid. |
|||
|
|||
""" |
|||
|
|||
def __init__(self, config): |
|||
|
|||
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") |
|||
self.left_action_embed = np.random.randn(2) |
|||
self.right_action_embed = np.random.randn(2) |
|||
framestack = config.get("framestack", 4) |
|||
|
|||
# env = gym.make(name) |
|||
# env = minigrid.wrappers.ImgObsWrapper(env) |
|||
# env = OneHotWrapper(env, |
|||
# config.vector_index if hasattr(config, "vector_index") else 0, |
|||
# framestack=framestack |
|||
# ) |
|||
self.wrapped = gym.make(name) |
|||
# self.observation_space = Dict( |
|||
# { |
|||
# "action_mask": None, |
|||
# "avail_actions": None, |
|||
# "cart": self.wrapped.observation_space, |
|||
# } |
|||
# ) |
|||
print(F"Wrapped environment is {self.wrapped}") |
|||
self.step_count = 0 |
|||
self.action_space = self.wrapped.action_space |
|||
self.observation_space = self.wrapped.observation_space |
|||
|
|||
|
|||
def update_avail_actions(self): |
|||
self.action_assignments = np.array( |
|||
[[0.0, 0.0]] * self.action_space.n, dtype=np.float32 |
|||
) |
|||
self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8) |
|||
self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2) |
|||
self.action_assignments[self.left_idx] = self.left_action_embed |
|||
self.action_assignments[self.right_idx] = self.right_action_embed |
|||
self.action_mask[self.left_idx] = 1 |
|||
self.action_mask[self.right_idx] = 1 |
|||
|
|||
def reset(self, *, seed=None, options=None): |
|||
self.update_avail_actions() |
|||
obs, infos = self.wrapped.reset() |
|||
return obs, infos |
|||
return { |
|||
"action_mask": self.action_mask, |
|||
"avail_action": self.action_assignments, |
|||
"cart": obs, |
|||
}, infos |
|||
|
|||
def step(self, action): |
|||
if action == self.left_idx: |
|||
actual_action = 0 |
|||
elif action == self.right_idx: |
|||
actual_action = 1 |
|||
else: |
|||
actual_action = 0 |
|||
# raise ValueError( |
|||
# "Chosen action was not one of the non-zero action embeddings", |
|||
# action, |
|||
# self.action_assignments, |
|||
# self.action_mask, |
|||
# self.left_idx, |
|||
# self.right_idx, |
|||
# ) |
|||
orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action) |
|||
self.update_avail_actions() |
|||
self.action_mask = self.action_mask.astype(np.int8) |
|||
print(F"Info is {info}") |
|||
info["Hello" : "Ich kenn mich nix aus"] |
|||
return orig_obs, rew, done, truncated, info |
|||
obs = { |
|||
"action_mask": self.action_mask, |
|||
"action_mask": self.action_assignments, |
|||
"cart": orig_obs, |
|||
} |
|||
return obs, rew, done, truncated, info |
|||
|
|||
|
Write
Preview
Loading…
Cancel
Save
Reference in new issue