You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

217 lines
8.5 KiB

11 months ago
11 months ago
11 months ago
  1. import gymnasium as gym
  2. import numpy as np
  3. import random
  4. from minigrid.core.actions import Actions
  5. from minigrid.core.constants import COLORS, OBJECT_TO_IDX, STATE_TO_IDX
  6. from gymnasium.spaces import Dict, Box
  7. from collections import deque
  8. from ray.rllib.utils.numpy import one_hot
  9. from utils import get_action_index_mapping, MiniGridShieldHandler, create_shield_query, ShieldingConfig
  10. class OneHotShieldingWrapper(gym.core.ObservationWrapper):
  11. def __init__(self, env, vector_index, framestack):
  12. super().__init__(env)
  13. self.framestack = framestack
  14. # 49=7x7 field of vision; 16=object types; 6=colors; 3=state types.
  15. # +4: Direction.
  16. self.single_frame_dim = 49 * (len(OBJECT_TO_IDX) + len(COLORS) + len(STATE_TO_IDX)) + 4
  17. self.init_x = None
  18. self.init_y = None
  19. self.x_positions = []
  20. self.y_positions = []
  21. self.x_y_delta_buffer = deque(maxlen=100)
  22. self.vector_index = vector_index
  23. self.frame_buffer = deque(maxlen=self.framestack)
  24. for _ in range(self.framestack):
  25. self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
  26. self.observation_space = Dict(
  27. {
  28. "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
  29. "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
  30. }
  31. )
  32. def observation(self, obs):
  33. # Debug output: max-x/y positions to watch exploration progress.
  34. # print(F"Initial observation in Wrapper {obs}")
  35. if self.step_count == 0:
  36. for _ in range(self.framestack):
  37. self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
  38. if self.vector_index == 0:
  39. if self.x_positions:
  40. max_diff = max(
  41. np.sqrt(
  42. (np.array(self.x_positions) - self.init_x) ** 2
  43. + (np.array(self.y_positions) - self.init_y) ** 2
  44. )
  45. )
  46. self.x_y_delta_buffer.append(max_diff)
  47. print(
  48. "100-average dist travelled={}".format(
  49. np.mean(self.x_y_delta_buffer)
  50. )
  51. )
  52. self.x_positions = []
  53. self.y_positions = []
  54. self.init_x = self.agent_pos[0]
  55. self.init_y = self.agent_pos[1]
  56. self.x_positions.append(self.agent_pos[0])
  57. self.y_positions.append(self.agent_pos[1])
  58. image = obs["data"]
  59. # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
  60. objects = one_hot(image[:, :, 0], depth=len(OBJECT_TO_IDX))
  61. colors = one_hot(image[:, :, 1], depth=len(COLORS))
  62. states = one_hot(image[:, :, 2], depth=len(STATE_TO_IDX))
  63. all_ = np.concatenate([objects, colors, states], -1)
  64. all_flat = np.reshape(all_, (-1,))
  65. direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
  66. single_frame = np.concatenate([all_flat, direction])
  67. self.frame_buffer.append(single_frame)
  68. tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] }
  69. return tmp
  70. class MiniGridShieldingWrapper(gym.core.Wrapper):
  71. def __init__(self,
  72. env,
  73. shield_creator : MiniGridShieldHandler,
  74. shield_query_creator,
  75. create_shield_at_reset=True,
  76. mask_actions=True):
  77. super(MiniGridShieldingWrapper, self).__init__(env)
  78. self.max_available_actions = env.action_space.n
  79. self.observation_space = Dict(
  80. {
  81. "data": env.observation_space.spaces["image"],
  82. "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
  83. }
  84. )
  85. self.shield_creator = shield_creator
  86. self.create_shield_at_reset = create_shield_at_reset
  87. self.shield = shield_creator.create_shield(env=self.env)
  88. self.mask_actions = mask_actions
  89. self.shield_query_creator = shield_query_creator
  90. print(F"Shielding is {self.mask_actions}")
  91. def create_action_mask(self):
  92. if not self.mask_actions:
  93. ret = np.array([1.0] * self.max_available_actions, dtype=np.int8)
  94. return ret
  95. cur_pos_str = self.shield_query_creator(self.env)
  96. # Create the mask
  97. # If shield restricts action mask only valid with 1.0
  98. # else set all actions as valid
  99. allowed_actions = []
  100. mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
  101. if cur_pos_str in self.shield and self.shield[cur_pos_str]:
  102. allowed_actions = self.shield[cur_pos_str]
  103. zeroes = np.array([0.0] * len(allowed_actions), dtype=np.int8)
  104. has_allowed_actions = False
  105. for allowed_action in allowed_actions:
  106. index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
  107. if index is None:
  108. assert(False)
  109. allowed = 1.0
  110. has_allowed_actions = True
  111. mask[index] = allowed
  112. else:
  113. for index, x in enumerate(mask):
  114. mask[index] = 1.0
  115. front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
  116. if front_tile is not None and front_tile.type == "key":
  117. mask[Actions.pickup] = 1.0
  118. if front_tile and front_tile.type == "door":
  119. mask[Actions.toggle] = 1.0
  120. # print(F"Mask is {mask} State: {cur_pos_str}")
  121. return mask
  122. def reset(self, *, seed=None, options=None):
  123. obs, infos = self.env.reset(seed=seed, options=options)
  124. if self.create_shield_at_reset and self.mask_actions:
  125. self.shield = self.shield_creator.create_shield(env=self.env)
  126. mask = self.create_action_mask()
  127. return {
  128. "data": obs["image"],
  129. "action_mask": mask
  130. }, infos
  131. def step(self, action):
  132. orig_obs, rew, done, truncated, info = self.env.step(action)
  133. mask = self.create_action_mask()
  134. obs = {
  135. "data": orig_obs["image"],
  136. "action_mask": mask,
  137. }
  138. return obs, rew, done, truncated, info
  139. def shielding_env_creater(config):
  140. name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
  141. framestack = config.get("framestack", 4)
  142. args = config.get("args", None)
  143. args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt"
  144. args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"
  145. shielding = config.get("shielding", False)
  146. shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,
  147. grid_to_prism_path=args.grid_to_prism_binary_path,
  148. prism_path=args.prism_path,
  149. formula=args.formula,
  150. shield_value=args.shield_value,
  151. prism_config=args.prism_config,
  152. shield_comparision=args.shield_comparision)
  153. probability_intended = args.probability_intended
  154. probability_displacement = args.probability_displacement
  155. probability_turn_intended = args.probability_turn_intended
  156. probability_turn_displacement = args.probability_turn_displacement
  157. env = gym.make(name,
  158. randomize_start=True,
  159. probability_intended=probability_intended,
  160. probability_displacement=probability_displacement,
  161. probability_turn_displacement=probability_turn_displacement,
  162. probability_turn_intended=probability_turn_intended)
  163. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding)
  164. env = OneHotShieldingWrapper(env,
  165. config.vector_index if hasattr(config, "vector_index") else 0,
  166. framestack=framestack
  167. )
  168. return env
  169. def register_minigrid_shielding_env(args):
  170. env_name = "mini-grid-shielding"
  171. register_env(env_name, shielding_env_creater)
  172. ModelCatalog.register_custom_model(
  173. "shielding_model",
  174. TorchActionMaskModel
  175. )