Browse Source

some fixes to key and door handling

refactoring
Thomas Knoll 1 year ago
parent
commit
906f251401
  1. 7
      examples/shields/rl/11_minigridrl.py
  2. 6
      examples/shields/rl/helpers.py
  3. 2
      examples/shields/rl/shieldhandlers.py
  4. 16
      examples/shields/rl/wrappers.py

7
examples/shields/rl/11_minigridrl.py

@ -24,9 +24,10 @@ def shielding_env_creater(config):
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name) env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator,
shield_query_creator=create_shield_query,
create_shield_at_reset=args.shield_creation_at_reset)
# env = minigrid.wrappers.ImgObsWrapper(env) # env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env) # env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env, env = OneHotShieldingWrapper(env,
@ -79,6 +80,8 @@ def ppo(args):
checkpoint_dir = algo.save() checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}") print(f"Checkpoint saved in directory {checkpoint_dir}")
algo.save()
def dqn(args): def dqn(args):
register_minigrid_shielding_env(args) register_minigrid_shielding_env(args)

6
examples/shields/rl/helpers.py

@ -71,8 +71,12 @@ def get_action_index_mapping(actions):
return Actions.done return Actions.done
elif "drop" in action_str: elif "drop" in action_str:
return Actions.drop return Actions.drop
elif "toggle" in action_str:
return Actions.toggle
elif "unlock" in action_str:
return Actions.toggle
raise ValueError(F"Action string {action_str} not supported")
return Actions.done

2
examples/shields/rl/shieldhandlers.py

@ -66,7 +66,7 @@ class MiniGridShieldHandler(ShieldHandler):
assert result.has_scheduler assert result.has_scheduler
assert result.has_shield assert result.has_shield
shield = result.shield shield = result.shield
# stormpy.shields.export_shield(model, shield, "Grid.shield")
stormpy.shields.export_shield(model, shield, "Grid.shield")
action_dictionary = {} action_dictionary = {}
shield_scheduler = shield.construct() shield_scheduler = shield.construct()

16
examples/shields/rl/wrappers.py

@ -8,7 +8,7 @@ from gymnasium.spaces import Dict, Box
from collections import deque from collections import deque
from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping, extract_keys
from helpers import get_action_index_mapping
from shieldhandlers import ShieldHandler from shieldhandlers import ShieldHandler
@ -67,7 +67,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
image = obs["data"] image = obs["data"]
# One-hot the last dim into 12, 6, 3 one-hot vectors, then flatten.
# One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
objects = one_hot(image[:, :, 0], depth=16) objects = one_hot(image[:, :, 0], depth=16)
colors = one_hot(image[:, :, 1], depth=6) colors = one_hot(image[:, :, 1], depth=6)
states = one_hot(image[:, :, 2], depth=3) states = one_hot(image[:, :, 2], depth=3)
@ -108,7 +108,9 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
return np.array([1.0] * self.max_available_actions, dtype=np.int8) return np.array([1.0] * self.max_available_actions, dtype=np.int8)
cur_pos_str = self.shield_query_creator(self.env) cur_pos_str = self.shield_query_creator(self.env)
# print(F"Pos string {cur_pos_str}")
# print(F"Shield {list(self.shield.keys())[0]}")
# print(cur_pos_str in self.shield)
# Create the mask # Create the mask
# If shield restricts action mask only valid with 1.0 # If shield restricts action mask only valid with 1.0
# else set all actions as valid # else set all actions as valid
@ -120,6 +122,8 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
for allowed_action in allowed_actions: for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set index = get_action_index_mapping(allowed_action.labels) # Allowed_action is a set
if index is None: if index is None:
print(F"No mapping for action {list(allowed_action.labels)}")
print(F"Shield at pos {cur_pos_str}, shield {self.shield[cur_pos_str]}")
assert(False) assert(False)
allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0] allowed = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
@ -134,8 +138,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
if front_tile is not None and front_tile.type == "key": if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0 mask[Actions.pickup] = 1.0
# if self.env.carrying:
# mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door": if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0 mask[Actions.toggle] = 1.0
@ -148,7 +150,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
if self.create_shield_at_reset and self.mask_actions: if self.create_shield_at_reset and self.mask_actions:
self.shield = self.shield_creator.create_shield(env=self.env) self.shield = self.shield_creator.create_shield(env=self.env)
self.keys = extract_keys(self.env)
mask = self.create_action_mask() mask = self.create_action_mask()
return { return {
"data": obs["image"], "data": obs["image"],
@ -164,7 +165,6 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
"action_mask": mask, "action_mask": mask,
} }
#print(F"Info is {info}")
return obs, rew, done, truncated, info return obs, rew, done, truncated, info
@ -222,10 +222,8 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = self.shield_creator.create_shield(env=self.env) shield = self.shield_creator.create_shield(env=self.env)
self.keys = keys
self.shield = shield self.shield = shield
return obs["image"], infos return obs["image"], infos

Loading…
Cancel
Save