|
@ -63,6 +63,7 @@ def register_minigrid_shielding_env(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ppo(args): |
|
|
def ppo(args): |
|
|
|
|
|
train_batch_size = 4000 |
|
|
register_minigrid_shielding_env(args) |
|
|
register_minigrid_shielding_env(args) |
|
|
|
|
|
|
|
|
config = (PPOConfig() |
|
|
config = (PPOConfig() |
|
@ -77,17 +78,17 @@ def ppo(args): |
|
|
"logdir": create_log_dir(args) |
|
|
"logdir": create_log_dir(args) |
|
|
}) |
|
|
}) |
|
|
# .exploration(exploration_config={"exploration_fraction": 0.1}) |
|
|
# .exploration(exploration_config={"exploration_fraction": 0.1}) |
|
|
.training(_enable_learner_api=False ,model={ |
|
|
|
|
|
"custom_model": "shielding_model" |
|
|
|
|
|
})) |
|
|
|
|
|
|
|
|
.training(_enable_learner_api=False , |
|
|
|
|
|
model={"custom_model": "shielding_model"}, |
|
|
|
|
|
train_batch_size=train_batch_size)) |
|
|
# config.entropy_coeff = 0.05 |
|
|
# config.entropy_coeff = 0.05 |
|
|
algo =( |
|
|
algo =( |
|
|
config.build() |
|
|
config.build() |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(args.evaluations): |
|
|
|
|
|
|
|
|
iterations = int((args.steps / train_batch_size)) + 1 |
|
|
|
|
|
for i in range(iterations): |
|
|
result = algo.train() |
|
|
result = algo.train() |
|
|
print(pretty_print(result)) |
|
|
print(pretty_print(result)) |
|
|
|
|
|
|
|
@ -99,6 +100,7 @@ def ppo(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dqn(args): |
|
|
def dqn(args): |
|
|
|
|
|
train_batch_size = 4000 |
|
|
register_minigrid_shielding_env(args) |
|
|
register_minigrid_shielding_env(args) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -113,7 +115,7 @@ def dqn(args): |
|
|
"type": TBXLogger, |
|
|
"type": TBXLogger, |
|
|
"logdir": create_log_dir(args) |
|
|
"logdir": create_log_dir(args) |
|
|
}) |
|
|
}) |
|
|
config = config.training(hiddens=[], dueling=False, model={ |
|
|
|
|
|
|
|
|
config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={ |
|
|
"custom_model": "shielding_model" |
|
|
"custom_model": "shielding_model" |
|
|
}) |
|
|
}) |
|
|
|
|
|
|
|
@ -121,7 +123,8 @@ def dqn(args): |
|
|
config.build() |
|
|
config.build() |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
for i in range(args.evaluations): |
|
|
|
|
|
|
|
|
iterations = int((args.steps / train_batch_size)) + 1 |
|
|
|
|
|
for i in range(iterations): |
|
|
result = algo.train() |
|
|
result = algo.train() |
|
|
print(pretty_print(result)) |
|
|
print(pretty_print(result)) |
|
|
|
|
|
|
|
|