Browse Source

Changed shield creation so that the shield is created on environment reset

refactoring
Thomas Knoll 1 year ago
parent
commit
1c2dbf706e
  1. 36
      examples/shields/rl/11_minigridrl.py
  2. 16
      examples/shields/rl/13_minigridsb.py
  3. 3
      examples/shields/rl/MaskModels.py
  4. 13
      examples/shields/rl/Wrapper.py
  5. 8
      examples/shields/rl/helpers.py

36
examples/shields/rl/11_minigridrl.py

@ -5,7 +5,7 @@ from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2 from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID from ray.rllib.utils.typing import PolicyID
from ray.rllib.algorithms.algorithm import Algorithm
import gymnasium as gym import gymnasium as gym
@ -29,9 +29,6 @@ from helpers import extract_keys, parse_arguments, create_shield_dict, create_lo
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class MyCallbacks(DefaultCallbacks): class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode started Environment: {base_env.get_sub_environments()}") # print(F"Episode started Environment: {base_env.get_sub_environments()}")
@ -50,7 +47,7 @@ class MyCallbacks(DefaultCallbacks):
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1 episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0] env = base_env.get_sub_environments()[0]
#print(env.printGrid())
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode end Environment: {base_env.get_sub_environments()}") # print(F"Episode end Environment: {base_env.get_sub_environments()}")
@ -65,10 +62,9 @@ def env_creater_custom(config):
shield = config.get("shield", {}) shield = config.get("shield", {})
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4) framestack = config.get("framestack", 4)
args = config.get("args", None)
env = gym.make(name) env = gym.make(name)
keys = extract_keys(env)
env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
env = MiniGridEnvWrapper(env, args=args)
# env = minigrid.wrappers.ImgObsWrapper(env) # env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env) # env = ImgObsWrapper(env)
env = OneHotWrapper(env, env = OneHotWrapper(env,
@ -76,6 +72,7 @@ def env_creater_custom(config):
framestack=framestack framestack=framestack
) )
return env return env
@ -96,12 +93,11 @@ def ppo(args):
register_custom_minigrid_env(args) register_custom_minigrid_env(args)
shield_dict = create_shield_dict(args)
config = (PPOConfig() config = (PPOConfig()
.rollouts(num_rollout_workers=1) .rollouts(num_rollout_workers=1)
.resources(num_gpus=0) .resources(num_gpus=0)
.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env})
.environment(env="mini-grid", env_config={"name": args.env, "args": args})
.framework("torch") .framework("torch")
.callbacks(MyCallbacks) .callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False) .rl_module(_enable_rl_module_api = False)
@ -111,7 +107,7 @@ def ppo(args):
}) })
.training(_enable_learner_api=False ,model={ .training(_enable_learner_api=False ,model={
"custom_model": "pa_model", "custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
"custom_model_config" : {"no_masking": args.no_masking}
})) }))
algo =( algo =(
@ -119,11 +115,7 @@ def ppo(args):
config.build() config.build()
) )
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)
# obs, reward, terminated, truncated = env.step(action)
for i in range(30):
for i in range(args.iterations):
result = algo.train() result = algo.train()
print(pretty_print(result)) print(pretty_print(result))
@ -131,18 +123,24 @@ def ppo(args):
checkpoint_dir = algo.save() checkpoint_dir = algo.save()
print(f"Checkpoint saved in directory {checkpoint_dir}") print(f"Checkpoint saved in directory {checkpoint_dir}")
# terminated = truncated = False
# while not terminated and not truncated:
# action = algo.compute_single_action(obs)
# obs, reward, terminated, truncated = env.step(action)
ray.shutdown() ray.shutdown()
def dqn(args): def dqn(args):
register_custom_minigrid_env(args) register_custom_minigrid_env(args)
shield_dict = create_shield_dict(args)
config = DQNConfig() config = DQNConfig()
config = config.resources(num_gpus=0) config = config.resources(num_gpus=0)
config = config.rollouts(num_rollout_workers=1) config = config.rollouts(num_rollout_workers=1)
config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env })
config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
config = config.framework("torch") config = config.framework("torch")
config = config.callbacks(MyCallbacks) config = config.callbacks(MyCallbacks)
config = config.rl_module(_enable_rl_module_api = False) config = config.rl_module(_enable_rl_module_api = False)
@ -152,7 +150,7 @@ def dqn(args):
}) })
config = config.training(hiddens=[], dueling=False, model={ config = config.training(hiddens=[], dueling=False, model={
"custom_model": "pa_model", "custom_model": "pa_model",
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
"custom_model_config" : {"no_masking": args.no_masking}
}) })
algo = ( algo = (

16
examples/shields/rl/13_minigridsb.py

@ -27,13 +27,12 @@ class CustomCallback(BaseCallback):
class MiniGridEnvWrapper(gym.core.Wrapper): class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, shield={}, keys=[], no_masking=False):
def __init__(self, env, args=None, no_masking=False):
super(MiniGridEnvWrapper, self).__init__(env) super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"] self.observation_space = env.observation_space.spaces["image"]
self.keys = keys
self.shield = shield
self.args = args
self.no_masking = no_masking self.no_masking = no_masking
def create_action_mask(self): def create_action_mask(self):
@ -94,6 +93,12 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = create_shield_dict(self.env, self.args)
self.keys = keys
self.shield = shield
return obs["image"], infos return obs["image"], infos
def step(self, action): def step(self, action):
@ -116,11 +121,10 @@ def mask_fn(env: gym.Env):
def main(): def main():
import argparse import argparse
args = parse_arguments(argparse) args = parse_arguments(argparse)
shield = create_shield_dict(args)
env = gym.make(args.env, render_mode="rgb_array") env = gym.make(args.env, render_mode="rgb_array")
keys = extract_keys(env)
env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
env = MiniGridEnvWrapper(env,args=args, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn) env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env) callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args)) model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))

3
examples/shields/rl/MaskModels.py

@ -34,9 +34,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
) )
nn.Module.__init__(self) nn.Module.__init__(self)
assert("shield" in custom_config)
self.shield = custom_config["shield"]
self.count = 0 self.count = 0
self.internal_model = TorchFC( self.internal_model = TorchFC(

13
examples/shields/rl/Wrapper.py

@ -7,7 +7,7 @@ from gymnasium.spaces import Dict, Box
from collections import deque from collections import deque
from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping
from helpers import get_action_index_mapping, create_shield_dict, extract_keys
class OneHotWrapper(gym.core.ObservationWrapper): class OneHotWrapper(gym.core.ObservationWrapper):
@ -86,7 +86,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
class MiniGridEnvWrapper(gym.core.Wrapper): class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, shield={}, keys=[]):
def __init__(self, env, args=None):
super(MiniGridEnvWrapper, self).__init__(env) super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n self.max_available_actions = env.action_space.n
self.observation_space = Dict( self.observation_space = Dict(
@ -95,8 +95,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
} }
) )
self.keys = keys
self.shield = shield
self.args = args
def create_action_mask(self): def create_action_mask(self):
coordinates = self.env.agent_pos coordinates = self.env.agent_pos
@ -140,8 +139,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
if front_tile is not None and front_tile.type == "key": if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0 mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
# if self.env.carrying:
# mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door": if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0 mask[Actions.toggle] = 1.0
@ -150,6 +149,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
self.shield = create_shield_dict(self.env, self.args)
self.keys = extract_keys(self.env)
mask = self.create_action_mask() mask = self.create_action_mask()
return { return {
"data": obs["image"], "data": obs["image"],

8
examples/shields/rl/helpers.py

@ -20,7 +20,7 @@ import os
def extract_keys(env): def extract_keys(env):
env.reset() env.reset()
keys = [] keys = []
print(env.grid)
#print(env.grid)
for j in range(env.grid.height): for j in range(env.grid.height):
for i in range(env.grid.width): for i in range(env.grid.width):
obj = env.grid.get(i,j) obj = env.grid.get(i,j)
@ -113,8 +113,8 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
program = stormpy.parse_prism_program(prism_path) program = stormpy.parse_prism_program(prism_path)
# formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
# formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
# shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
# stormpy.logic.ShieldComparison.ABSOLUTE, 0.9) # stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
@ -150,7 +150,7 @@ def create_shield(grid_to_prism_path, grid_file, prism_path):
return action_dictionary return action_dictionary
def create_shield_dict(args):
def create_shield_dict(env, args):
env = create_environment(args) env = create_environment(args)
# print(env.printGrid(init=False)) # print(env.printGrid(init=False))

Loading…
Cancel
Save