diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index 035b195..ff7d8e1 100644
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -5,7 +5,7 @@
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.evaluation.episode_v2 import EpisodeV2
 from ray.rllib.policy import Policy
 from ray.rllib.utils.typing import PolicyID
-
+from ray.rllib.algorithms.algorithm import Algorithm
 
 import gymnasium as gym
@@ -29,9 +29,6 @@ from helpers import extract_keys, parse_arguments, create_shield_dict, create_lo
 import matplotlib.pyplot as plt
 
 
-
-
-
 class MyCallbacks(DefaultCallbacks):
     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
@@ -50,7 +47,7 @@ class MyCallbacks(DefaultCallbacks):
     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
         episode.user_data["count"] = episode.user_data["count"] + 1
         env = base_env.get_sub_environments()[0]
-        #print(env.printGrid())
+        # print(env.printGrid())
 
     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
@@ -65,10 +62,9 @@ def env_creater_custom(config):
     shield = config.get("shield", {})
     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
     framestack = config.get("framestack", 4)
-
+    args = config.get("args", None)
     env = gym.make(name)
-    keys = extract_keys(env)
-    env = MiniGridEnvWrapper(env, shield=shield, keys=keys)
+    env = MiniGridEnvWrapper(env, args=args)
     # env = minigrid.wrappers.ImgObsWrapper(env)
     # env = ImgObsWrapper(env)
     env = OneHotWrapper(env,
@@ -76,6 +72,7 @@
                         framestack=framestack
                         )
 
+
     return env
 
 
@@ -96,12 +93,11 @@ def ppo(args):
     register_custom_minigrid_env(args)
 
-    shield_dict = create_shield_dict(args)
 
 
     config = (PPOConfig()
         .rollouts(num_rollout_workers=1)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env})
+        .environment(env="mini-grid", env_config={"name": args.env, "args": args})
         .framework("torch")
         .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
@@ -111,7 +107,7 @@
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "pa_model",
-            "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
+            "custom_model_config" : {"no_masking": args.no_masking}
         }))
 
     algo =(
@@ -119,11 +115,7 @@
         config.build()
     )
 
 
-    # while not terminated and not truncated:
-    #     action = algo.compute_single_action(obs)
-    #     obs, reward, terminated, truncated = env.step(action)
-
-    for i in range(30):
+    for i in range(args.iterations):
         result = algo.train()
         print(pretty_print(result))
@@ -131,18 +123,24 @@
         checkpoint_dir = algo.save()
         print(f"Checkpoint saved in directory {checkpoint_dir}")
 
+    # terminated = truncated = False
+
+    # while not terminated and not truncated:
+    #     action = algo.compute_single_action(obs)
+    #     obs, reward, terminated, truncated = env.step(action)
+
+
     ray.shutdown()
 
 
 def dqn(args):
     register_custom_minigrid_env(args)
-    shield_dict = create_shield_dict(args)
 
 
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=1)
-    config = config.environment(env="mini-grid", env_config={"shield": shield_dict, "name": args.env })
+    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
     config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
@@ -152,7 +150,7 @@
         })
     config = config.training(hiddens=[], dueling=False, model={
         "custom_model": "pa_model",
-        "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
+        "custom_model_config" : {"no_masking": args.no_masking}
     })
 
     algo = (
diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py
index 5873ed6..6a2fb20 100644
--- a/examples/shields/rl/13_minigridsb.py
+++ b/examples/shields/rl/13_minigridsb.py
@@ -27,13 +27,12 @@ class CustomCallback(BaseCallback):
 
 
 class MiniGridEnvWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield={}, keys=[], no_masking=False):
+    def __init__(self, env, args=None, no_masking=False):
         super(MiniGridEnvWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = env.observation_space.spaces["image"]
 
-        self.keys = keys
-        self.shield = shield
+        self.args = args
         self.no_masking = no_masking
 
     def create_action_mask(self):
@@ -94,6 +93,12 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
 
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
+
+        keys = extract_keys(self.env)
+        shield = create_shield_dict(self.env, self.args)
+
+        self.keys = keys
+        self.shield = shield
         return obs["image"], infos
 
     def step(self, action):
@@ -116,11 +121,10 @@ def mask_fn(env: gym.Env):
 
 def main():
     import argparse
     args = parse_arguments(argparse)
-    shield = create_shield_dict(args)
+
     env = gym.make(args.env, render_mode="rgb_array")
-    keys = extract_keys(env)
-    env = MiniGridEnvWrapper(env, shield=shield, keys=keys, no_masking=args.no_masking)
+    env = MiniGridEnvWrapper(env,args=args, no_masking=args.no_masking)
     env = ActionMasker(env, mask_fn)
     callback = CustomCallback(1, env)
     model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py
index 607b529..0ee4154 100644
--- a/examples/shields/rl/MaskModels.py
+++ b/examples/shields/rl/MaskModels.py
@@ -34,9 +34,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         )
         nn.Module.__init__(self)
 
-        assert("shield" in custom_config)
-
-        self.shield = custom_config["shield"]
         self.count = 0
 
         self.internal_model = TorchFC(
diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py
index 41aaf2a..6650369 100644
--- a/examples/shields/rl/Wrapper.py
+++ b/examples/shields/rl/Wrapper.py
@@ -7,7 +7,7 @@ from gymnasium.spaces import Dict, Box
 from collections import deque
 from ray.rllib.utils.numpy import one_hot
 
-from helpers import get_action_index_mapping
+from helpers import get_action_index_mapping, create_shield_dict, extract_keys
 
 
 class OneHotWrapper(gym.core.ObservationWrapper):
@@ -86,7 +86,7 @@ class OneHotWrapper(gym.core.ObservationWrapper):
 
 
 class MiniGridEnvWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield={}, keys=[]):
+    def __init__(self, env, args=None):
         super(MiniGridEnvWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -95,8 +95,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
                 "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
             }
         )
-        self.keys = keys
-        self.shield = shield
+        self.args = args
 
     def create_action_mask(self):
         coordinates = self.env.agent_pos
@@ -140,8 +139,8 @@
         if front_tile is not None and front_tile.type == "key":
             mask[Actions.pickup] = 1.0
 
-        if self.env.carrying:
-            mask[Actions.drop] = 1.0
+        # if self.env.carrying:
+        #     mask[Actions.drop] = 1.0
 
         if front_tile and front_tile.type == "door":
             mask[Actions.toggle] = 1.0
@@ -150,6 +149,8 @@
 
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
+        self.shield = create_shield_dict(self.env, self.args)
+        self.keys = extract_keys(self.env)
         mask = self.create_action_mask()
         return {
             "data": obs["image"],
diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py
index 714ac00..ca15f14 100644
--- a/examples/shields/rl/helpers.py
+++ b/examples/shields/rl/helpers.py
@@ -20,7 +20,7 @@ import os
 def extract_keys(env):
     env.reset()
     keys = []
-    print(env.grid)
+    #print(env.grid)
     for j in range(env.grid.height):
         for i in range(env.grid.width):
             obj = env.grid.get(i,j)
@@ -113,8 +113,8 @@
 
     program = stormpy.parse_prism_program(prism_path)
 
-    # formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
-    formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
+    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
+    # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
 
     # shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY,
     #                                                       stormpy.logic.ShieldComparison.ABSOLUTE, 0.9)
@@ -150,7 +150,7 @@
 
     return action_dictionary
 
 
-def create_shield_dict(args):
+def create_shield_dict(env, args):
     env = create_environment(args)
     # print(env.printGrid(init=False))
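
Usage note (commentary, not part of the patch): the net effect of this refactor is that the shield is no longer precomputed in the entry points and threaded through env_config and custom_model_config; each wrapper now receives the parsed CLI args and rebuilds the shield and key list on every reset() via create_shield_dict(env, args) and extract_keys(env). A minimal sketch of the resulting call pattern, using only names introduced by this diff (which argparse flags exist is whatever parse_arguments defines):

    import argparse

    import gymnasium as gym

    from helpers import parse_arguments
    from Wrapper import MiniGridEnvWrapper

    args = parse_arguments(argparse)

    # The wrapper no longer takes shield= or keys=; it only stores args.
    env = gym.make(args.env)
    env = MiniGridEnvWrapper(env, args=args)

    # The shield dict and key list are now built here, once per episode.
    obs, infos = env.reset()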