You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

48 lines
1.4 KiB

11 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from utils import MiniGridShieldHandler, common_parser
  5. class MiniGridSbShieldingWrapper(gym.core.Wrapper):
  6. def __init__(self,
  7. env,
  8. shield_handler : MiniGridShieldHandler,
  9. create_shield_at_reset = True,
  10. mask_actions=True,
  11. ):
  12. super().__init__(env)
  13. self.observation_space = env.observation_space.spaces["image"]
  14. self.shield_handler = shield_handler
  15. self.mask_actions = mask_actions
  16. self.create_shield_at_reset = create_shield_at_reset
  17. shield = self.shield_handler.create_shield(env=self.env)
  18. self.shield = shield
  19. def create_action_mask(self):
  20. try:
  21. return self.shield[self.env.get_symbolic_state()]
  22. except:
  23. return [1.0] * 3 + [1.0] * 4
  24. def reset(self, *, seed=None, options=None):
  25. obs, infos = self.env.reset(seed=seed, options=options)
  26. if self.create_shield_at_reset and self.mask_actions:
  27. shield = self.shield_handler.create_shield(env=self.env)
  28. self.shield = shield
  29. return obs["image"], infos
  30. def step(self, action):
  31. orig_obs, rew, done, truncated, info = self.env.step(action)
  32. obs = orig_obs["image"]
  33. return obs, rew, done, truncated, info
  34. def parse_sb3_arguments():
  35. parser = common_parser()
  36. args = parser.parse_args()
  37. return args