Browse Source

changed iteration handling

refactoring
Thomas Knoll 11 months ago
parent
commit
ae94b57876
  1. 17
      examples/shields/rl/11_minigridrl.py
  2. 9
      examples/shields/rl/14_train_eval.py

17
examples/shields/rl/11_minigridrl.py

@ -63,6 +63,7 @@ def register_minigrid_shielding_env(args):
def ppo(args): def ppo(args):
train_batch_size = 4000
register_minigrid_shielding_env(args) register_minigrid_shielding_env(args)
config = (PPOConfig() config = (PPOConfig()
@ -77,17 +78,17 @@ def ppo(args):
"logdir": create_log_dir(args) "logdir": create_log_dir(args)
}) })
# .exploration(exploration_config={"exploration_fraction": 0.1}) # .exploration(exploration_config={"exploration_fraction": 0.1})
.training(_enable_learner_api=False ,model={
"custom_model": "shielding_model"
}))
.training(_enable_learner_api=False ,
model={"custom_model": "shielding_model"},
train_batch_size=train_batch_size))
# config.entropy_coeff = 0.05 # config.entropy_coeff = 0.05
algo =( algo =(
config.build() config.build()
) )
for i in range(args.evaluations):
iterations = int((args.steps / train_batch_size)) + 1
for i in range(iterations):
result = algo.train() result = algo.train()
print(pretty_print(result)) print(pretty_print(result))
@ -99,6 +100,7 @@ def ppo(args):
def dqn(args): def dqn(args):
train_batch_size = 4000
register_minigrid_shielding_env(args) register_minigrid_shielding_env(args)
@ -113,7 +115,7 @@ def dqn(args):
"type": TBXLogger, "type": TBXLogger,
"logdir": create_log_dir(args) "logdir": create_log_dir(args)
}) })
config = config.training(hiddens=[], dueling=False, model={
config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={
"custom_model": "shielding_model" "custom_model": "shielding_model"
}) })
@ -121,7 +123,8 @@ def dqn(args):
config.build() config.build()
) )
for i in range(args.evaluations):
iterations = int((args.steps / train_batch_size)) + 1
for i in range(iterations):
result = algo.train() result = algo.train()
print(pretty_print(result)) print(pretty_print(result))

9
examples/shields/rl/14_train_eval.py

@ -53,7 +53,7 @@ def register_minigrid_shielding_env(args):
def ppo(args): def ppo(args):
register_minigrid_shielding_env(args) register_minigrid_shielding_env(args)
train_batch_size = 4000
config = (PPOConfig() config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers) .rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0) .resources(num_gpus=0)
@ -74,18 +74,17 @@ def ppo(args):
}) })
.training(_enable_learner_api=False ,model={ .training(_enable_learner_api=False ,model={
"custom_model": "shielding_model" "custom_model": "shielding_model"
}))
}, train_batch_size=train_batch_size))
algo =( algo =(
config.build() config.build()
) )
evaluations = args.evaluations
iterations = int((args.steps / train_batch_size)) + 1
for i in range(evaluations):
for i in range(iterations):
algo.train() algo.train()
if i % 5 == 0: if i % 5 == 0:

Loading…
Cancel
Save