Browse Source

added steps argument and stop criteria

refactoring
Thomas Knoll 1 year ago
parent
commit
f0df936716
  1. 1
      examples/shields/rl/15_train_eval_tune.py
  2. 4
      examples/shields/rl/helpers.py

1
examples/shields/rl/15_train_eval_tune.py

@ -86,6 +86,7 @@ def ppo(args):
), ),
run_config=air.RunConfig( run_config=air.RunConfig(
stop = {"episode_reward_mean": 94, stop = {"episode_reward_mean": 94,
"timesteps_total": args.steps,
"training_iteration": args.iterations}, "training_iteration": args.iterations},
checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ),
storage_path=F"{logdir}" storage_path=F"{logdir}"

4
examples/shields/rl/helpers.py

@ -90,11 +90,11 @@ def parse_arguments(argparse):
parser.add_argument("--prism_path", default="grid") parser.add_argument("--prism_path", default="grid")
parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"]) parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
parser.add_argument("--log_dir", default="../log_results/") parser.add_argument("--log_dir", default="../log_results/")
parser.add_argument("--iterations", type=int, default=30 )
parser.add_argument("--iterations", type=int, default=10 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
parser.add_argument("--workers", type=int, default=1) parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
parser.add_argument("--steps", default=20_000, type=int)
args = parser.parse_args() args = parser.parse_args()

Loading…
Cancel
Save