From fe96a6a0b6e8207e2d5d00da4bb7d5d1cb2e2a04 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Thu, 24 Aug 2023 11:01:11 +0200 Subject: [PATCH] added dqn algorithm --- examples/shields/rl/11_minigridrl.py | 47 +++++++++++++++------------- examples/shields/rl/MaskModels.py | 3 +- examples/shields/rl/Wrapper.py | 6 ++-- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 5697f63..fea882a 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -30,7 +30,6 @@ from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterat from ray import tune, air from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.tune.logger import pretty_print -from ray.rllib.algorithms import ppo from ray.rllib.models import ModelCatalog from ray.rllib.utils.torch_utils import FLOAT_MIN @@ -196,14 +195,15 @@ def create_environment(args): return env -def register_custom_minigrid_env(): +def register_custom_minigrid_env(args): env_name = "mini-grid" register_env(env_name, env_creater_custom) + ModelCatalog.register_custom_model( "pa_model", TorchActionMaskModel ) - + def create_shield_dict(args): env = create_environment(args) # print(env.pprint_grid()) @@ -216,7 +216,7 @@ def create_shield_dict(args): shield_dict = create_shield(grid_file, prism_path) #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} - print(F"Shield dictionary {shield_dict}") + #print(F"Shield dictionary {shield_dict}") # for state_id in model.states: # choices = shield.get_choice(state_id) # print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") @@ -228,7 +228,7 @@ def ppo(args): ray.init(num_cpus=3) - register_custom_minigrid_env() + register_custom_minigrid_env(args) shield_dict = create_shield_dict(args) config = (PPOConfig() @@ -268,30 +268,35 @@ def ppo(args): def dqn(args): - config = DQNConfig() - register_custom_minigrid_env() + register_custom_minigrid_env(args) shield_dict = create_shield_dict(args) - replay_config = config.replay_buffer_config.update( - { - "capacity": 60000, - "prioritized_replay_alpha": 0.5, - "prioritized_replay_beta": 0.5, - "prioritized_replay_eps": 3e-6, - } - ) + - config = config.training(replay_buffer_config=replay_config, model={ - "custom_model": "pa_model", - "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} - }) + config = DQNConfig() config = config.resources(num_gpus=0) config = config.rollouts(num_rollout_workers=1) + config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) config = config.framework("torch") config = config.callbacks(MyCallbacks) config = config.rl_module(_enable_rl_module_api = False) + config = config.training(hiddens=[], dueling=False, model={ + "custom_model": "pa_model", + "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} + }) - config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) - + algo = ( + config.build() + ) + + for i in range(30): + result = algo.train() + print(pretty_print(result)) + + if i % 5 == 0: + checkpoint_dir = algo.save() + print(f"Checkpoint saved in directory {checkpoint_dir}") + + ray.shutdown() def main(): diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/MaskModels.py index 4d9baaf..e36c6da 100644 --- a/examples/shields/rl/MaskModels.py +++ b/examples/shields/rl/MaskModels.py @@ -1,5 +1,4 @@ from typing import Dict, Optional, Union - from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC from ray.rllib.models.tf.fcnet import FullyConnectedNetwork @@ -75,6 +74,8 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): # assert(False) # Convert action_mask into a [0.0 || -inf]-type mask. inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) + # print(F"Logits Size: {logits.size()} Inf-Mask Size: {inf_mask.size()}") + # print(F"Logits:{logits} Inf-Mask: {inf_mask}") masked_logits = logits + inf_mask # print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}") diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrapper.py index 6058dca..f4cef3f 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrapper.py @@ -109,7 +109,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper): # else set everything to one mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) - if cur_pos_str in self.shield: + if cur_pos_str in self.shield and self.shield[cur_pos_str]: allowed_actions = self.shield[cur_pos_str] for allowed_action in allowed_actions: index = allowed_action[0] @@ -118,8 +118,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper): for index, x in enumerate(mask): mask[index] = 1.0 - - #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask}") + + #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})") return mask