You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

70 lines
2.4 KiB

11 months ago
11 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from utils import MiniGridShieldHandler, create_shield_query
  5. class MiniGridSbShieldingWrapper(gym.core.Wrapper):
  6. def __init__(self,
  7. env,
  8. shield_creator : MiniGridShieldHandler,
  9. shield_query_creator,
  10. create_shield_at_reset = True,
  11. mask_actions=True,
  12. ):
  13. super(MiniGridSbShieldingWrapper, self).__init__(env)
  14. self.max_available_actions = env.action_space.n
  15. self.observation_space = env.observation_space.spaces["image"]
  16. self.shield_creator = shield_creator
  17. self.mask_actions = mask_actions
  18. self.shield_query_creator = shield_query_creator
  19. def create_action_mask(self):
  20. if not self.mask_actions:
  21. return np.array([1.0] * self.max_available_actions, dtype=np.int8)
  22. cur_pos_str = self.shield_query_creator(self.env)
  23. allowed_actions = []
  24. # Create the mask
  25. # If shield restricts actions, mask only valid actions with 1.0
  26. # else set all actions valid
  27. mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
  28. if cur_pos_str in self.shield and self.shield[cur_pos_str]:
  29. allowed_actions = self.shield[cur_pos_str]
  30. for allowed_action in allowed_actions:
  31. index = get_action_index_mapping(allowed_action.labels)
  32. if index is None:
  33. assert(False)
  34. mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
  35. else:
  36. for index, x in enumerate(mask):
  37. mask[index] = 1.0
  38. front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
  39. if front_tile and front_tile.type == "door":
  40. mask[Actions.toggle] = 1.0
  41. return mask
  42. def reset(self, *, seed=None, options=None):
  43. obs, infos = self.env.reset(seed=seed, options=options)
  44. shield = self.shield_creator.create_shield(env=self.env)
  45. self.shield = shield
  46. return obs["image"], infos
  47. def step(self, action):
  48. orig_obs, rew, done, truncated, info = self.env.step(action)
  49. obs = orig_obs["image"]
  50. return obs, rew, done, truncated, info