|
@ -30,7 +30,6 @@ from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterat |
|
|
from ray import tune, air |
|
|
from ray import tune, air |
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks |
|
|
from ray.rllib.algorithms.callbacks import DefaultCallbacks |
|
|
from ray.tune.logger import pretty_print |
|
|
from ray.tune.logger import pretty_print |
|
|
from ray.rllib.algorithms import ppo |
|
|
|
|
|
from ray.rllib.models import ModelCatalog |
|
|
from ray.rllib.models import ModelCatalog |
|
|
|
|
|
|
|
|
from ray.rllib.utils.torch_utils import FLOAT_MIN |
|
|
from ray.rllib.utils.torch_utils import FLOAT_MIN |
|
@ -196,9 +195,10 @@ def create_environment(args): |
|
|
return env |
|
|
return env |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def register_custom_minigrid_env(): |
|
|
|
|
|
|
|
|
def register_custom_minigrid_env(args): |
|
|
env_name = "mini-grid" |
|
|
env_name = "mini-grid" |
|
|
register_env(env_name, env_creater_custom) |
|
|
register_env(env_name, env_creater_custom) |
|
|
|
|
|
|
|
|
ModelCatalog.register_custom_model( |
|
|
ModelCatalog.register_custom_model( |
|
|
"pa_model", |
|
|
"pa_model", |
|
|
TorchActionMaskModel |
|
|
TorchActionMaskModel |
|
@ -216,7 +216,7 @@ def create_shield_dict(args): |
|
|
shield_dict = create_shield(grid_file, prism_path) |
|
|
shield_dict = create_shield(grid_file, prism_path) |
|
|
#shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} |
|
|
#shield_dict = {state.id : shield.get_choice(state).choice_map for state in model.states} |
|
|
|
|
|
|
|
|
print(F"Shield dictionary {shield_dict}") |
|
|
|
|
|
|
|
|
#print(F"Shield dictionary {shield_dict}") |
|
|
# for state_id in model.states: |
|
|
# for state_id in model.states: |
|
|
# choices = shield.get_choice(state_id) |
|
|
# choices = shield.get_choice(state_id) |
|
|
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") |
|
|
# print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") |
|
@ -228,7 +228,7 @@ def ppo(args): |
|
|
ray.init(num_cpus=3) |
|
|
ray.init(num_cpus=3) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
register_custom_minigrid_env() |
|
|
|
|
|
|
|
|
register_custom_minigrid_env(args) |
|
|
shield_dict = create_shield_dict(args) |
|
|
shield_dict = create_shield_dict(args) |
|
|
|
|
|
|
|
|
config = (PPOConfig() |
|
|
config = (PPOConfig() |
|
@ -268,30 +268,35 @@ def ppo(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dqn(args): |
|
|
def dqn(args): |
|
|
config = DQNConfig() |
|
|
|
|
|
register_custom_minigrid_env() |
|
|
|
|
|
|
|
|
register_custom_minigrid_env(args) |
|
|
shield_dict = create_shield_dict(args) |
|
|
shield_dict = create_shield_dict(args) |
|
|
replay_config = config.replay_buffer_config.update( |
|
|
|
|
|
{ |
|
|
|
|
|
"capacity": 60000, |
|
|
|
|
|
"prioritized_replay_alpha": 0.5, |
|
|
|
|
|
"prioritized_replay_beta": 0.5, |
|
|
|
|
|
"prioritized_replay_eps": 3e-6, |
|
|
|
|
|
} |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
config = config.training(replay_buffer_config=replay_config, model={ |
|
|
|
|
|
"custom_model": "pa_model", |
|
|
|
|
|
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} |
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config = DQNConfig() |
|
|
config = config.resources(num_gpus=0) |
|
|
config = config.resources(num_gpus=0) |
|
|
config = config.rollouts(num_rollout_workers=1) |
|
|
config = config.rollouts(num_rollout_workers=1) |
|
|
|
|
|
config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) |
|
|
config = config.framework("torch") |
|
|
config = config.framework("torch") |
|
|
config = config.callbacks(MyCallbacks) |
|
|
config = config.callbacks(MyCallbacks) |
|
|
config = config.rl_module(_enable_rl_module_api = False) |
|
|
config = config.rl_module(_enable_rl_module_api = False) |
|
|
|
|
|
config = config.training(hiddens=[], dueling=False, model={ |
|
|
|
|
|
"custom_model": "pa_model", |
|
|
|
|
|
"custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} |
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) |
|
|
|
|
|
|
|
|
algo = ( |
|
|
|
|
|
config.build() |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
for i in range(30): |
|
|
|
|
|
result = algo.train() |
|
|
|
|
|
print(pretty_print(result)) |
|
|
|
|
|
|
|
|
|
|
|
if i % 5 == 0: |
|
|
|
|
|
checkpoint_dir = algo.save() |
|
|
|
|
|
print(f"Checkpoint saved in directory {checkpoint_dir}") |
|
|
|
|
|
|
|
|
|
|
|
ray.shutdown() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |