diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index e66cfc1..7e5660d 100755
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -63,6 +63,7 @@ def register_minigrid_shielding_env(args):
 
 
 def ppo(args):
+    train_batch_size = 4000
     register_minigrid_shielding_env(args)
 
     config = (PPOConfig()
@@ -77,17 +78,17 @@ def ppo(args):
             "logdir": create_log_dir(args)
         })
         # .exploration(exploration_config={"exploration_fraction": 0.1})
-        .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model"
-        }))
+        .training(_enable_learner_api=False ,
+            model={"custom_model": "shielding_model"},
+            train_batch_size=train_batch_size))
 
 
     # config.entropy_coeff = 0.05
     algo =(
         config.build()
     )
-
-    for i in range(args.evaluations):
+    iterations = int((args.steps / train_batch_size)) + 1
+    for i in range(iterations):
         result = algo.train()
         print(pretty_print(result))
 
@@ -99,6 +100,7 @@ def ppo(args):
 
 
 def dqn(args):
+    train_batch_size = 4000
     register_minigrid_shielding_env(args)
 
 
@@ -113,15 +115,16 @@ def dqn(args):
             "type": TBXLogger,
             "logdir": create_log_dir(args)
         })
-    config = config.training(hiddens=[], dueling=False, model={
+    config = config.training(hiddens=[], dueling=False, train_batch_size=train_batch_size, model={
         "custom_model": "shielding_model"
     })
 
     algo = (
         config.build()
     )
-
-    for i in range(args.evaluations):
+
+    iterations = int((args.steps / train_batch_size)) + 1
+    for i in range(iterations):
         result = algo.train()
         print(pretty_print(result))
 
diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py
index 56fa8ee..dae1c77 100644
--- a/examples/shields/rl/14_train_eval.py
+++ b/examples/shields/rl/14_train_eval.py
@@ -53,7 +53,7 @@ def register_minigrid_shielding_env(args):
 
 def ppo(args):
     register_minigrid_shielding_env(args)
-
+    train_batch_size = 4000
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
@@ -74,18 +74,17 @@ def ppo(args):
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "shielding_model"
-        }))
+        }, train_batch_size=train_batch_size))
 
 
     algo =(
         config.build()
     )
-    evaluations = args.evaluations
-
+    iterations = int((args.steps / train_batch_size)) + 1
 
 
 
-    for i in range(evaluations):
+    for i in range(iterations):
         algo.train()
 
         if i % 5 == 0:
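
Note on the change above: both training loops now derive their iteration count from a total timestep budget (`args.steps`) and the `train_batch_size` of 4000 passed to `.training(...)`, instead of looping over `args.evaluations`. A minimal sketch of that computation, purely for illustration and not part of the patch (the helper name `compute_training_iterations` is hypothetical; the formula and the 4000-step batch size are taken from the diff):

def compute_training_iterations(steps: int, train_batch_size: int = 4000) -> int:
    # Each call to algo.train() consumes roughly `train_batch_size` environment steps,
    # so int(steps / train_batch_size) + 1 rounds the iteration count up and always
    # adds one extra iteration, which guarantees the step budget is covered.
    return int(steps / train_batch_size) + 1

assert compute_training_iterations(10_000) == 3  # 2.5 batches of experience -> 3 iterations
assert compute_training_iterations(8_000) == 3   # an exact multiple still gets one extra iteration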