From e42becef886ecf276f41ea4fdbbd3cfb301aaafe Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Wed, 23 Aug 2023 16:13:49 +0200 Subject: [PATCH] added dqn handling skeleton --- examples/shields/rl/11_minigridrl.py | 84 +++++++++++++++++++------ examples/shields/rl/MaskEnvironments.py | 4 +- examples/shields/rl/MaskModels.py | 10 +-- examples/shields/rl/Wrapper.py | 73 ++++++--------------- 4 files changed, 92 insertions(+), 79 deletions(-) diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index f77c7e6..5697f63 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -25,6 +25,7 @@ import numpy as np import ray from ray.tune import register_env from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.dqn.dqn import DQNConfig from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator from ray import tune, air from ray.rllib.algorithms.callbacks import DefaultCallbacks @@ -37,7 +38,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN from ray.rllib.models.preprocessors import get_preprocessor from MaskEnvironments import ParametricActionsMiniGridEnv from MaskModels import TorchActionMaskModel -from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper +from Wrapper import OneHotWrapper, MiniGridEnvWrapper import matplotlib.pyplot as plt @@ -62,7 +63,7 @@ class MyCallbacks(DefaultCallbacks): def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: episode.user_data["count"] = episode.user_data["count"] + 1 env = base_env.get_sub_environments()[0] - print(env.env.env.printGrid()) + #print(env.env.env.printGrid()) def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") @@ -83,6 +84,7 @@ def parse_arguments(argparse): parser.add_argument("--grid_path", default="Grid.txt") parser.add_argument("--prism_path", default="Grid.PRISM") parser.add_argument("--no_masking", default=False) + parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) args = parser.parse_args() @@ -108,13 +110,13 @@ def env_creater_custom(config): framestack=framestack ) - obs = env.observation_space.sample() - obs2, infos = env.reset(seed=None, options={}) + # obs = env.observation_space.sample() + # obs2, infos = env.reset(seed=None, options={}) - print(F"Obs is {obs} before reset. After reset: {obs2}") + # print(F"Obs is {obs} before reset. After reset: {obs2}") # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) - print(F"Created Custom Minigrid Environment is {env}") + # print(F"Created Custom Minigrid Environment is {env}") return env @@ -194,12 +196,16 @@ def create_environment(args): return env -def main(): - args = parse_arguments(argparse) +def register_custom_minigrid_env(): + env_name = "mini-grid" + register_env(env_name, env_creater_custom) + ModelCatalog.register_custom_model( + "pa_model", + TorchActionMaskModel + ) +def create_shield_dict(args): env = create_environment(args) - ray.init(num_cpus=3) - # print(env.pprint_grid()) # print(env.printGrid(init=False)) @@ -214,20 +220,22 @@ def main(): # for state_id in model.states: # choices = shield.get_choice(state_id) # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") - - env_name = "mini-grid" - register_env(env_name, env_creater_custom) - ModelCatalog.register_custom_model( - "pa_model", - TorchActionMaskModel - ) + + return shield_dict + +def ppo(args): + + ray.init(num_cpus=3) + + + register_custom_minigrid_env() + shield_dict = create_shield_dict(args) config = (PPOConfig() .rollouts(num_rollout_workers=1) .resources(num_gpus=0) .environment(env="mini-grid", env_config={"shield": shield_dict }) .framework("torch") - .experimental(_disable_preprocessor_api=False) .callbacks(MyCallbacks) .rl_module(_enable_rl_module_api = False) .training(_enable_learner_api=False ,model={ @@ -255,10 +263,48 @@ def main(): if i % 5 == 0: checkpoint_dir = algo.save() print(f"Checkpoint saved in directory {checkpoint_dir}") + + ray.shutdown() + + +def dqn(args): + config = DQNConfig() + register_custom_minigrid_env() + shield_dict = create_shield_dict(args) + replay_config = config.replay_buffer_config.update( + { + "capacity": 60000, + "prioritized_replay_alpha": 0.5, + "prioritized_replay_beta": 0.5, + "prioritized_replay_eps": 3e-6, + } + ) + + config = config.training(replay_buffer_config=replay_config, model={ + "custom_model": "pa_model", + "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} + }) + config = config.resources(num_gpus=0) + config = config.rollouts(num_rollout_workers=1) + config = config.framework("torch") + config = config.callbacks(MyCallbacks) + config = config.rl_module(_enable_rl_module_api = False) + + config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) + + + +def main(): + args = parse_arguments(argparse) + + if args.algorithm == "ppo": + ppo(args) + elif args.algorithm == "dqn": + dqn(args) + - ray.shutdown() if __name__ == '__main__': main() \ No newline at end of file diff --git a/examples/shields/rl/MaskEnvironments.py b/examples/shields/rl/MaskEnvironments.py index ccd63ba..f1c8d16 100644 --- a/examples/shields/rl/MaskEnvironments.py +++ b/examples/shields/rl/MaskEnvironments.py @@ -56,7 +56,7 @@ class ParametricActionsMiniGridEnv(gym.Env): return obs, infos return { "action_mask": self.action_mask, - "avail_actions": self.action_assignments, + "avail_action": self.action_assignments, "cart": obs, }, infos @@ -83,7 +83,7 @@ class ParametricActionsMiniGridEnv(gym.Env): return orig_obs, rew, done, truncated, info obs = { "action_mask": self.action_mask, - "avail_actions": self.action_assignments, + "action_mask": self.action_assignments, "cart": orig_obs, } return obs, rew, done, truncated, info diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py index 71b4418..4d9baaf 100644 --- a/examples/shields/rl/MaskModels.py +++ b/examples/shields/rl/MaskModels.py @@ -25,10 +25,10 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): ): orig_space = getattr(obs_space, "original_space", obs_space) custom_config = model_config['custom_model_config'] - print(F"Original Space is: {orig_space}") + # print(F"Original Space is: {orig_space}") #print(model_config) - print(F"Observation space in model: {obs_space}") - print(F"Provided action space in model {action_space}") + #print(F"Observation space in model: {obs_space}") + #print(F"Provided action space in model {action_space}") TorchModelV2.__init__( self, obs_space, action_space, num_outputs, model_config, name, **kwargs @@ -65,7 +65,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): # print(F"Caluclated Logits {logits} with size {logits.size()} Count: {self.count}") - action_mask = input_dict["obs"]["avail_actions"] + action_mask = input_dict["obs"]["action_mask"] #print(F"Action mask is {action_mask} with dimension {action_mask.size()}") # If action masking is disabled, directly return unmasked logits @@ -77,7 +77,7 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) masked_logits = logits + inf_mask - print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}") + # print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}") # # Return masked logits. return masked_logits, state diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index 183676f..6058dca 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -26,12 +26,12 @@ class OneHotWrapper(gym.core.ObservationWrapper): self.observation_space = Dict( { "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32), - "avail_actions": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), + "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int), } ) - print(F"Set obersvation space to {self.observation_space}") + # print(F"Set obersvation space to {self.observation_space}") def observation(self, obs): @@ -77,7 +77,7 @@ class OneHotWrapper(gym.core.ObservationWrapper): self.frame_buffer.append(single_frame) #obs["one-hot"] = np.concatenate(self.frame_buffer) - tmp = {"data": np.concatenate(self.frame_buffer), "avail_actions": obs["avail_actions"] } + tmp = {"data": np.concatenate(self.frame_buffer), "action_mask": obs["action_mask"] } return tmp#np.concatenate(self.frame_buffer) @@ -88,7 +88,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper): self.observation_space = Dict( { "data": env.observation_space.spaces["image"], - "avail_actions" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), + "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), } ) @@ -98,7 +98,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper): def create_action_mask(self): coordinates = self.env.agent_pos view_direction = self.env.agent_dir - print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") + #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") cur_pos_str = f"[!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" allowed_actions = [] @@ -109,73 +109,40 @@ class MiniGridEnvWrapper(gym.core.Wrapper): # else set everything to one mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) - # if cur_pos_str in self.shield: - # allowed_actions = self.shield[cur_pos_str] - # for allowed_action in allowed_actions: - # index = allowed_action[0] - # mask[index] = 1.0 - # else: - # for index in len(mask): - # mask[index] = 1.0 + if cur_pos_str in self.shield: + allowed_actions = self.shield[cur_pos_str] + for allowed_action in allowed_actions: + index = allowed_action[0] + mask[index] = 1.0 + else: + for index, x in enumerate(mask): + mask[index] = 1.0 - print(F"Allowed actions for position {coordinates} and view {view_direction} are {allowed_actions}") - mask[0] = 1.0 + #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask}") + return mask def reset(self, *, seed=None, options=None): obs, infos = self.env.reset() + mask = self.create_action_mask() return { "data": obs["image"], - "avail_actions": np.array([0.0] * self.max_available_actions, dtype=np.int8) + "action_mask": mask }, infos def step(self, action): - print(F"Performed action in step: {action}") + # print(F"Performed action in step: {action}") orig_obs, rew, done, truncated, info = self.env.step(action) - actions = self.create_action_mask() + mask = self.create_action_mask() #print(F"Original observation is {orig_obs}") obs = { "data": orig_obs["image"], - "avail_actions": actions, + "action_mask": mask, } #print(F"Info is {info}") return obs, rew, done, truncated, info - - -class ImgObsWrapper(gym.core.ObservationWrapper): - """ - Use the image as the only observation output, no language/mission. - - Example: - >>> import gymnasium as gym - >>> from minigrid.wrappers import ImgObsWrapper - >>> env = gym.make("MiniGrid-Empty-5x5-v0") - >>> obs, _ = env.reset() - >>> obs.keys() - dict_keys(['image', 'direction', 'mission']) - >>> env = ImgObsWrapper(env) - >>> obs, _ = env.reset() - >>> obs.shape - (7, 7, 3) - """ - - def __init__(self, env): - """A wrapper that makes image the only observation. - - Args: - env: The environment to apply the wrapper - """ - super().__init__(env) - self.observation_space = env.observation_space.spaces["image"] - print(F"Set obersvation space to {self.observation_space}") - - def observation(self, obs): - #print(F"obs in img obs wrapper {obs}") - tmp = {"data": obs["image"], "Test": obs["Test"]} - - return tmp