You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

70 lines
2.4 KiB

2 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from utils import MiniGridShieldHandler, create_shield_query
  class MiniGridSbShieldingWrapper(gym.core.Wrapper):
      """Wrapper that exposes only the ``"image"`` observation of the wrapped
      env and builds an action mask from a shield.

      The shield is produced by ``shield_creator.create_shield(env=...)`` during
      :meth:`reset` and is queried with the key returned by
      ``shield_query_creator(env)``; see :meth:`create_action_mask`.
      """

      def __init__(self,
                   env,
                   shield_creator: MiniGridShieldHandler,
                   shield_query_creator,
                   create_shield_at_reset=True,
                   mask_actions=True,
                   ):
          """
          Args:
              env: Environment to wrap. Its observation space must have an
                  ``"image"`` sub-space and its action space must expose ``n``
                  (presumably a ``Discrete`` space).
              shield_creator: Object whose ``create_shield(env=...)`` returns the
                  shield consulted by :meth:`create_action_mask`.
              shield_query_creator: Callable ``(env) -> key`` giving the shield
                  lookup key for the current state.
              create_shield_at_reset: If True (default), rebuild the shield on
                  every :meth:`reset`; if False, build it on the first reset only
                  and reuse it afterwards.
              mask_actions: If False, :meth:`create_action_mask` allows every
                  action without consulting the shield.
          """
          super().__init__(env)
          self.max_available_actions = env.action_space.n
          self.observation_space = env.observation_space.spaces["image"]
          self.shield_creator = shield_creator
          self.shield_query_creator = shield_query_creator
          self.create_shield_at_reset = create_shield_at_reset
          self.mask_actions = mask_actions
          # Built in reset(); None means no shield exists yet.
          self.shield = None

      def create_action_mask(self):
          """Return an int8 mask of length ``max_available_actions`` (1 = allowed).

          If masking is disabled, every action is allowed. Otherwise, when the
          shield has a non-empty entry for the current state, each listed action
          is enabled independently with probability ``action.prob`` and every
          other action is masked out. When the state is absent from the shield,
          or its entry is empty, every action is allowed. In all masked cases
          the toggle action is additionally allowed when the agent faces a door.

          Raises:
              RuntimeError: If called before :meth:`reset` has built a shield.
              ValueError: If a shield action's labels cannot be mapped to an
                  action index.
          """
          if not self.mask_actions:
              return np.ones(self.max_available_actions, dtype=np.int8)
          if self.shield is None:
              raise RuntimeError(
                  "No shield available; call reset() before create_action_mask()."
              )

          cur_pos_str = self.shield_query_creator(self.env)
          if cur_pos_str in self.shield and self.shield[cur_pos_str]:
              mask = np.zeros(self.max_available_actions, dtype=np.int8)
              for allowed_action in self.shield[cur_pos_str]:
                  # NOTE(review): get_action_index_mapping is not imported at the
                  # top of this file; presumably it is provided by utils — confirm.
                  index = get_action_index_mapping(allowed_action.labels)
                  if index is None:
                      raise ValueError(
                          f"Cannot map shield action labels "
                          f"{allowed_action.labels!r} to an action index"
                      )
                  # Stochastic shield: enable this action with probability prob.
                  mask[index] = random.choices(
                      [0, 1],
                      weights=(1 - allowed_action.prob, allowed_action.prob),
                  )[0]
          else:
              # The shield records no restriction for this state.
              mask = np.ones(self.max_available_actions, dtype=np.int8)

          # Toggling a door directly in front of the agent is always allowed,
          # even when the shield did not enable it. Go through .unwrapped so the
          # lookup does not depend on wrapper attribute forwarding.
          base_env = self.env.unwrapped
          front_x, front_y = base_env.front_pos
          front_tile = base_env.grid.get(front_x, front_y)
          if front_tile is not None and front_tile.type == "door":
              mask[base_env.actions.toggle] = 1
          return mask

      def reset(self, *, seed=None, options=None):
          """Reset the wrapped env and return ``(image_obs, infos)``.

          The shield is (re)built from the freshly reset env when
          ``create_shield_at_reset`` is True, or when no shield exists yet.
          """
          obs, infos = self.env.reset(seed=seed, options=options)
          if self.create_shield_at_reset or self.shield is None:
              self.shield = self.shield_creator.create_shield(env=self.env)
          return obs["image"], infos

      def step(self, action):
          """Step the wrapped env; return its 5-tuple with only the ``"image"`` observation."""
          obs, rew, terminated, truncated, info = self.env.step(action)
          return obs["image"], rew, terminated, truncated, info