
added dqn algorithm

refactoring
Thomas Knoll, 1 year ago, parent commit fe96a6a0b6

 examples/shields/rl/11_minigridrl.py | 47
 examples/shields/rl/MaskModels.py    |  3
 examples/shields/rl/Wrapper.py       |  6
 3 files changed

examples/shields/rl/11_minigridrl.py  (47 changes)

@@ -30,7 +30,6 @@ from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterat
 from ray import tune, air
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.tune.logger import pretty_print
-from ray.rllib.algorithms import ppo
 from ray.rllib.models import ModelCatalog
 from ray.rllib.utils.torch_utils import FLOAT_MIN
@@ -196,14 +195,15 @@ def create_environment(args):
     return env

-def register_custom_minigrid_env():
+def register_custom_minigrid_env(args):
     env_name = "mini-grid"
     register_env(env_name, env_creater_custom)

     ModelCatalog.register_custom_model(
         "pa_model",
         TorchActionMaskModel
     )

 def create_shield_dict(args):
     env = create_environment(args)
     # print(env.pprint_grid())
@@ -216,7 +216,7 @@ def create_shield_dict(args):
     shield_dict = create_shield(grid_file, prism_path)
     #shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states}
-    print(F"Shield dictionary {shield_dict}")
+    #print(F"Shield dictionary {shield_dict}")

     # for state_id in model.states:
     #     choices = shield.get_choice(state_id)
     #     print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")
@@ -228,7 +228,7 @@ def ppo(args):
     ray.init(num_cpus=3)

-    register_custom_minigrid_env()
+    register_custom_minigrid_env(args)
     shield_dict = create_shield_dict(args)

     config = (PPOConfig()
@@ -268,30 +268,35 @@ def ppo(args):
 def dqn(args):
-    config = DQNConfig()
-    register_custom_minigrid_env()
+    register_custom_minigrid_env(args)
     shield_dict = create_shield_dict(args)

-    replay_config = config.replay_buffer_config.update(
-        {
-            "capacity": 60000,
-            "prioritized_replay_alpha": 0.5,
-            "prioritized_replay_beta": 0.5,
-            "prioritized_replay_eps": 3e-6,
-        }
-    )
-    config = config.training(replay_buffer_config=replay_config, model={
-        "custom_model": "pa_model",
-        "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
-    })
+    config = DQNConfig()
+    config = config.resources(num_gpus=0)
+    config = config.rollouts(num_rollout_workers=1)
+    config = config.environment(env="mini-grid", env_config={"shield": shield_dict })
+    config = config.framework("torch")
+    config = config.callbacks(MyCallbacks)
+    config = config.rl_module(_enable_rl_module_api = False)
+    config = config.training(hiddens=[], dueling=False, model={
+        "custom_model": "pa_model",
+        "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking}
+    })
     config = config.environment(env="mini-grid", env_config={"shield": shield_dict })

     algo = (
         config.build()
     )

     for i in range(30):
         result = algo.train()
         print(pretty_print(result))

         if i % 5 == 0:
             checkpoint_dir = algo.save()
             print(f"Checkpoint saved in directory {checkpoint_dir}")

     ray.shutdown()

 def main():
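A note on the new training call: DQN normally appends its own Q-head layers (`hiddens`) and a dueling value/advantage split after the model, which would transform the masked logits and wash out the near-infinite negative entries. Setting `hiddens=[]` and `dueling=False` makes the custom model's masked output serve directly as Q-values. Below is a minimal sketch of just that configuration; `shield_dict` is stubbed out (the script builds it via `create_shield_dict(args)`), and actually building the algorithm still requires the env/model registrations from the script.

    from ray.rllib.algorithms.dqn import DQNConfig

    shield_dict = {}  # placeholder; the script obtains this from create_shield_dict(args)

    config = (
        DQNConfig()
        .framework("torch")
        .training(
            hiddens=[],      # no extra Q-head layers after the custom model
            dueling=False,   # no value/advantage decomposition of the output
            model={
                "custom_model": "pa_model",  # registered via ModelCatalog in the script
                "custom_model_config": {"shield": shield_dict, "no_masking": False},
            },
        )
    )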

examples/shields/rl/MaskModels.py  (3 changes)
@@ -1,5 +1,4 @@
 from typing import Dict, Optional, Union
-from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
 from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
@@ -75,6 +74,8 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         # assert(False)

         # Convert action_mask into a [0.0 || -inf]-type mask.
         inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
+        # print(F"Logits Size: {logits.size()} Inf-Mask Size: {inf_mask.size()}")
+        # print(F"Logits:{logits} Inf-Mask: {inf_mask}")
         masked_logits = logits + inf_mask

         # print(F"Infinity mask {inf_mask}, Masked logits {masked_logits}")

examples/shields/rl/Wrapper.py  (6 changes)
@@ -109,7 +109,7 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
         # else set everything to one
         mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)

-        if cur_pos_str in self.shield:
+        if cur_pos_str in self.shield and self.shield[cur_pos_str]:
             allowed_actions = self.shield[cur_pos_str]
             for allowed_action in allowed_actions:
                 index = allowed_action[0]
@@ -118,8 +118,8 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
             for index, x in enumerate(mask):
                 mask[index] = 1.0

-        #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask}")
+        #print(F"Action Mask for position {coordinates} and view {view_direction} is {mask} Position String: {cur_pos_str})")
         return mask
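The guard added here matters because a shield entry can exist but be empty; `if cur_pos_str in self.shield:` alone would then skip the permissive else-branch and leave the all-zero mask in place, forbidding every action. A self-contained sketch of the lookup logic with a hypothetical shield (the state-string format and entries are made up for illustration):

    import numpy as np

    # Hypothetical shield: "1,1,N" allows actions 0 and 2; "2,1,E" exists
    # but is empty, standing in for a state without usable shield data.
    shield = {"1,1,N": [(0,), (2,)], "2,1,E": []}
    max_available_actions = 3

    def create_action_mask(cur_pos_str):
        mask = np.array([0.0] * max_available_actions, dtype=np.int8)
        # An empty list is falsy, so empty entries fall through to the
        # permissive branch instead of forbidding every action.
        if cur_pos_str in shield and shield[cur_pos_str]:
            for allowed_action in shield[cur_pos_str]:
                mask[allowed_action[0]] = 1.0
        else:
            for index, _ in enumerate(mask):
                mask[index] = 1.0
        return mask

    print(create_action_mask("1,1,N"))  # [1 0 1]
    print(create_action_mask("2,1,E"))  # [1 1 1] rather than [0 0 0]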
