You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
70 lines
2.4 KiB
70 lines
2.4 KiB
import gymnasium as gym
|
|
import numpy as np
|
|
import random
|
|
|
|
from utils import MiniGridShieldHandler, create_shield_query
|
|
|
|
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Wrap an environment to expose image observations and shield-based action masks.

    The wrapper replaces the dict observation of the wrapped environment with
    its ``"image"`` entry and offers :meth:`create_action_mask`, which turns
    the entry of the current shield for the current state into a 0/1 mask
    over the environment's discrete actions.

    NOTE(review): ``get_action_index_mapping`` and ``Actions`` are used by
    :meth:`create_action_mask` but are not imported in the visible part of
    this file -- confirm they are brought into scope elsewhere.
    """

    def __init__(self,
                 env,
                 shield_creator: MiniGridShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True,
                 ):
        """Store the shielding helpers and narrow the observation space.

        Args:
            env: Environment to wrap. Its action space must expose ``n`` and
                its observation space must be a dict space with an
                ``"image"`` entry.
            shield_creator: Object whose ``create_shield(env=...)`` builds the
                shield that is queried by :meth:`create_action_mask`.
            shield_query_creator: Callable mapping the wrapped environment to
                the key under which the current state is stored in the shield.
            create_shield_at_reset: If true, a fresh shield is built on every
                :meth:`reset`. If false, the shield is built once on the
                first :meth:`reset` and reused afterwards.
            mask_actions: If false, :meth:`create_action_mask` allows every
                action and never consults the shield.
        """
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]

        self.shield_creator = shield_creator
        self.shield_query_creator = shield_query_creator
        self.create_shield_at_reset = create_shield_at_reset
        self.mask_actions = mask_actions
        # Built by reset(); None until the first reset has happened.
        self.shield = None

    def create_action_mask(self):
        """Return an ``int8`` array with 1 for enabled and 0 for disabled actions.

        Without masking every action is enabled. Otherwise, if the shield has
        a non-empty entry for the current state, each action listed there is
        enabled at random with weight ``allowed_action.prob`` (assumed to lie
        in [0, 1] -- TODO confirm) and all other actions are disabled; if the
        shield has no such entry, every action is enabled. The toggle action
        is always enabled when the tile in front of the agent is a door.

        Raises:
            RuntimeError: If masking is enabled but no shield exists yet
                because :meth:`reset` has not been called.
            AssertionError: If a shield action cannot be mapped to an index.
        """
        num_actions = self.max_available_actions
        if not self.mask_actions:
            return np.ones(num_actions, dtype=np.int8)

        if self.shield is None:
            raise RuntimeError(
                "No shield available: call reset() before create_action_mask()."
            )

        cur_pos_str = self.shield_query_creator(self.env)
        mask = np.zeros(num_actions, dtype=np.int8)

        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
            # The shield restricts this state: only listed actions may be enabled.
            for allowed_action in self.shield[cur_pos_str]:
                index = get_action_index_mapping(allowed_action.labels)
                if index is None:
                    # Raised explicitly instead of via ``assert`` so the check
                    # survives ``python -O``; indexing the mask with None
                    # would otherwise overwrite every entry at once.
                    raise AssertionError(
                        f"No action index for shield labels {allowed_action.labels!r}"
                    )

                keep_prob = allowed_action.prob
                mask[index] = random.choices(
                    [0.0, 1.0], weights=(1 - keep_prob, keep_prob)
                )[0]
        else:
            # The shield does not restrict this state: all actions are valid.
            mask[:] = 1

        front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
        if front_tile and front_tile.type == "door":
            # Doors can always be toggled, regardless of the shield.
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped environment and (re)build the shield if required.

        The shield is rebuilt on every reset when ``create_shield_at_reset``
        is true; otherwise it is built only if none exists yet.

        Returns:
            The ``"image"`` entry of the initial observation and the info
            object returned by the wrapped environment.
        """
        obs, infos = self.env.reset(seed=seed, options=options)

        if self.create_shield_at_reset or self.shield is None:
            self.shield = self.shield_creator.create_shield(env=self.env)

        return obs["image"], infos

    def step(self, action):
        """Step the wrapped environment and return only the image observation.

        Returns:
            The tuple returned by the wrapped environment's ``step`` with the
            dict observation replaced by its ``"image"`` entry.
        """
        orig_obs, rew, done, truncated, info = self.env.step(action)
        return orig_obs["image"], rew, done, truncated, info
|
|
|