Browse Source

added trial name

refactoring
Thomas Knoll 1 year ago
parent
commit
bcc19ec9ca
  1. 8
      examples/shields/rl/15_train_eval_tune.py

8
examples/shields/rl/15_train_eval_tune.py

@ -3,6 +3,7 @@ import minigrid
import ray
from ray.tune import register_env
from ray.tune.experiment.trial import Trial
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger
@ -50,6 +51,9 @@ def register_minigrid_shielding_env(args):
"shielding_model",
TorchActionMaskModel
)
def trial_name_creator(trial : Trial):
return "trial"
def ppo(args):
@ -83,6 +87,7 @@ def ppo(args):
metric="episode_reward_mean",
mode="max",
num_samples=1,
trial_name_creator=trial_name_creator,
),
run_config=air.RunConfig(
stop = {"episode_reward_mean": 94,
@ -93,7 +98,8 @@ def ppo(args):
),
storage_path=F"{logdir}",
name=test_name(args)
name=test_name(args),
)
,

Loading…
Cancel
Save