From f0df936716d9e9b2458a2fc066c3f9813f9ad77e Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Mon, 11 Sep 2023 10:05:54 +0200 Subject: [PATCH] added steps argument and stop criteria --- examples/shields/rl/15_train_eval_tune.py | 1 + examples/shields/rl/helpers.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index e7a163c..187c0bd 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -86,6 +86,7 @@ def ppo(args): ), run_config=air.RunConfig( stop = {"episode_reward_mean": 94, + "timesteps_total": args.steps, "training_iteration": args.iterations}, checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), storage_path=F"{logdir}" diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 6fc37bd..3e67c81 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -90,11 +90,11 @@ def parse_arguments(argparse): parser.add_argument("--prism_path", default="grid") parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) parser.add_argument("--log_dir", default="../log_results/") - parser.add_argument("--iterations", type=int, default=30 ) + parser.add_argument("--iterations", type=int, default=10 ) parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" parser.add_argument("--workers", type=int, default=1) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) - + parser.add_argument("--steps", default=20_000, type=int) args = parser.parse_args()