You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
|
|
import gymnasium as gym import numpy as np import random
from utils import MiniGridShieldHandler, common_parser
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Gymnasium wrapper that serves shield-based action masks to Stable-Baselines3.

    The wrapped environment returns dict observations. This wrapper forwards
    only their ``"image"`` entry and narrows ``observation_space`` to match.
    A shield is built via ``shield_handler.create_shield(env=...)`` when the
    wrapper is constructed. It is optionally rebuilt after each ``reset``.
    ``create_action_mask`` looks up the env's current symbolic state in it.

    Args:
        env: Environment to wrap. Its observation space must be a dict space
            with an ``"image"`` key, and it must provide
            ``get_symbolic_state()``.
        shield_handler: Object whose ``create_shield(env=...)`` returns the
            shield. The shield is indexed by the symbolic state (presumably a
            dict of state -> action mask; confirm against the handler).
        create_shield_at_reset: If True (and ``mask_actions`` is True),
            recompute the shield after every ``reset``.
        mask_actions: Whether action masking is enabled. Within this class it
            only gates the shield rebuild in ``reset``.
    """

    # Length of the all-permissive fallback mask. This is the same 7 entries
    # the original produced, presumably one per MiniGrid discrete action.
    _FALLBACK_MASK_SIZE = 7

    def __init__(self,
                 env,
                 shield_handler : MiniGridShieldHandler,
                 create_shield_at_reset = True,
                 mask_actions=True,
                 ):
        super().__init__(env)
        # Downstream consumers only ever see the image part of the observation.
        self.observation_space = env.observation_space.spaces["image"]

        self.shield_handler = shield_handler
        self.mask_actions = mask_actions
        self.create_shield_at_reset = create_shield_at_reset

        # Built unconditionally, so create_action_mask works even before the
        # first reset (and when create_shield_at_reset is False).
        self.shield = self.shield_handler.create_shield(env=self.env)

    def create_action_mask(self):
        """Return the action mask for the wrapped env's current symbolic state.

        A state that is absent from the shield is treated as unrestricted, so
        every action is allowed.

        Returns:
            The shield entry for the current state, or a list of
            ``_FALLBACK_MASK_SIZE`` ones when the state is not in the shield.
        """
        state = self.env.get_symbolic_state()
        try:
            return self.shield[state]
        except KeyError:
            # Only a missing state means "no restriction". Any other failure
            # (a broken env, a malformed shield, KeyboardInterrupt) must
            # propagate. Silently returning an all-ones mask would quietly
            # disable shielding.
            return [1.0] * self._FALLBACK_MASK_SIZE

    def reset(self, *, seed=None, options=None):
        """Reset the wrapped env and return ``(image_observation, info)``.

        When both ``create_shield_at_reset`` and ``mask_actions`` are True,
        the shield is rebuilt from the freshly reset environment.
        """
        obs, infos = self.env.reset(seed=seed, options=options)

        # Rebuild after the reset, presumably so the shield matches the new
        # episode's environment. It is skipped when masking is disabled.
        if self.create_shield_at_reset and self.mask_actions:
            self.shield = self.shield_handler.create_shield(env=self.env)
        return obs["image"], infos

    def step(self, action):
        """Step the wrapped env and return its 5-tuple with only the image observation."""
        orig_obs, rew, terminated, truncated, info = self.env.step(action)
        return orig_obs["image"], rew, terminated, truncated, info
def parse_sb3_arguments():
    """Parse command-line arguments with the shared parser from ``utils``.

    Returns:
        The parsed arguments as returned by ``common_parser().parse_args()``
        (presumably an ``argparse.Namespace``).
    """
    return common_parser().parse_args()
|