Browse Source

simple masking (only turn left allowed)

refactoring
Thomas Knoll 1 year ago
parent
commit
f52262ad11
  1. 39
      examples/shields/rl/11_minigridrl.py
  2. 23
      examples/shields/rl/MaskModels.py
  3. 47
      examples/shields/rl/Wrapper.py

39
examples/shields/rl/11_minigridrl.py

@ -82,6 +82,7 @@ def parse_arguments(argparse):
parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees")
parser.add_argument("--grid_path", default="Grid.txt")
parser.add_argument("--prism_path", default="Grid.PRISM")
parser.add_argument("--no_masking", default=False)
args = parser.parse_args()
@ -92,14 +93,14 @@ def env_creater_custom(config):
# name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
# # name = config.get("name", "MiniGrid-Empty-8x8-v0")
framestack = config.get("framestack", 4)
shield = config.get("shield", {})
# env = gym.make(name)
# env = ParametricActionsMiniGridEnv(config)
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
env = gym.make(name)
env = MiniGridEnvWrapper(env)
env = MiniGridEnvWrapper(env, shield=shield)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotWrapper(env,
@ -163,10 +164,21 @@ def create_shield(grid_file, prism_path):
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return shield.construct(), model
return action_dictionary
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
@ -195,13 +207,13 @@ def main():
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield, model = create_shield(grid_file, prism_path)
shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
shield_dict = create_shield(grid_file, prism_path)
#shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
print(shield_dict)
for state_id in model.states:
choices = shield.get_choice(state_id)
print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
print(F"Shield dictionary {shield_dict}")
# for state_id in model.states:
# choices = shield.get_choice(state_id)
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
env_name = "mini-grid"
register_env(env_name, env_creater_custom)
@ -213,14 +225,14 @@ def main():
config = (PPOConfig()
.rollouts(num_rollout_workers=1)
.resources(num_gpus=0)
.environment(env="mini-grid")
.environment(env="mini-grid", env_config={"shield": shield_dict })
.framework("torch")
.experimental(_disable_preprocessor_api=False)
.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False)
.training(_enable_learner_api=False ,model={
"custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": True}
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
# "fcnet_hiddens": [256,256],
# "fcnet_activation": "relu",
@ -231,9 +243,6 @@ def main():
config.build()
)
episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)

23
examples/shields/rl/MaskModels.py

@ -28,6 +28,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
print(F"Original Space is: {orig_space}")
#print(model_config)
print(F"Observation space in model: {obs_space}")
print(F"Provided action space in model {action_space}")
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
@ -37,6 +38,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
assert("shield" in custom_config)
self.shield = custom_config["shield"]
self.count = 0
self.internal_model = TorchFC(
orig_space["data"],
@ -54,28 +56,31 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
def forward(self, input_dict, state, seq_lens):
# Extract the available actions tensor from the observation.
# print(F"Input dict is {input_dict} at obs: {input_dict['obs']}")
# print(F"State is {state}")
action_mask = []
# print(F"State is {state}")
# print(input_dict["env"])
# Compute the unmasked logits.
logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})
# print(F"Calculated Logits {logits} with size {logits.size()} Count: {self.count}")
action_mask = input_dict["obs"]["avail_actions"]
#print(F"Action mask is {action_mask} with dimension {action_mask.size()}")
# If action masking is disabled, directly return unmasked logits
if self.no_masking:
return logits, state
assert(False)
return logits, state
# assert(False)
# Convert action_mask into a [0.0 || -inf]-type mask.
# inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
# masked_logits = logits + inf_mask
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask
print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")
# # Return masked logits.
# return masked_logits, state
return masked_logits, state
def value_function(self):
# Delegates to the wrapped model: returns whatever self.internal_model.value_function()
# returns, unchanged. self.internal_model is the TorchFC instance assigned elsewhere in
# this class (constructed on the "data" part of the original observation space).
return self.internal_model.value_function()

47
examples/shields/rl/Wrapper.py

@ -26,7 +26,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
self.observation_space = Dict(
{
"data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
"avail_actions": gym.spaces.Box(0, 10, shape=(10,), dtype=int),
"avail_actions": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
}
)
@ -82,34 +82,63 @@ class OneHotWrapper(gym.core.ObservationWrapper):
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env):
def __init__(self, env, shield):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = Dict(
{
"data": env.observation_space.spaces["image"],
"avail_actions" : Box(0, 10, shape=(10,), dtype=np.int8),
"avail_actions" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
}
)
self.shield = shield
def create_action_mask(self):
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
def test(self):
print("Testing some stuff")
# Create the mask
# If the shield restricts the current state, set only the allowed actions to 1.0;
# otherwise set every entry to 1.0 (all actions allowed).
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
# if cur_pos_str in self.shield:
# allowed_actions = self.shield[cur_pos_str]
# for allowed_action in allowed_actions:
# index = allowed_action[0]
# mask[index] = 1.0
# else:
# for index in len(mask):
# mask[index] = 1.0
print(F"Allowed actions for position {coordinates} and view {view_direction} are {allowed_actions}")
mask[0] = 1.0
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset()
return {
"data": obs["image"],
"avail_actions": np.array([0.0] * 10, dtype=np.int8)
"avail_actions": np.array([0.0] * self.max_available_actions, dtype=np.int8)
}, infos
def step(self, action):
print(F"Performed action in step: {action}")
orig_obs, rew, done, truncated, info = self.env.step(action)
self.test()
actions = self.create_action_mask()
#print(F"Original observation is {orig_obs}")
obs = {
"data": orig_obs["image"],
"avail_actions": np.array([0.0] * 10, dtype=np.int8),
"avail_actions": actions,
}
#print(F"Info is {info}")

Loading…
Cancel
Save