diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index 30d356c..1cafca0 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -3,6 +3,7 @@ import minigrid import ray from ray.tune import register_env +from ray.tune.experiment.trial import Trial from ray import tune, air from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger @@ -50,6 +51,9 @@ def register_minigrid_shielding_env(args): "shielding_model", TorchActionMaskModel ) + +def trial_name_creator(trial : Trial): + return "trial" def ppo(args): @@ -83,6 +87,7 @@ def ppo(args): metric="episode_reward_mean", mode="max", num_samples=1, + trial_name_creator=trial_name_creator, ), run_config=air.RunConfig( stop = {"episode_reward_mean": 94, @@ -93,7 +98,8 @@ def ppo(args): ), storage_path=F"{logdir}", - name=test_name(args) + name=test_name(args), + ) ,