Browse Source

added trial name

refactoring
Thomas Knoll 1 year ago
parent
commit
bcc19ec9ca
  1. 8
      examples/shields/rl/15_train_eval_tune.py

8
examples/shields/rl/15_train_eval_tune.py

@ -3,6 +3,7 @@ import minigrid
import ray import ray
from ray.tune import register_env from ray.tune import register_env
from ray.tune.experiment.trial import Trial
from ray import tune, air from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import UnifiedLogger from ray.tune.logger import UnifiedLogger
@ -51,6 +52,9 @@ def register_minigrid_shielding_env(args):
TorchActionMaskModel TorchActionMaskModel
) )
def trial_name_creator(trial : Trial):
return "trial"
def ppo(args): def ppo(args):
register_minigrid_shielding_env(args) register_minigrid_shielding_env(args)
@ -83,6 +87,7 @@ def ppo(args):
metric="episode_reward_mean", metric="episode_reward_mean",
mode="max", mode="max",
num_samples=1, num_samples=1,
trial_name_creator=trial_name_creator,
), ),
run_config=air.RunConfig( run_config=air.RunConfig(
stop = {"episode_reward_mean": 94, stop = {"episode_reward_mean": 94,
@ -93,7 +98,8 @@ def ppo(args):
), ),
storage_path=F"{logdir}", storage_path=F"{logdir}",
name=test_name(args)
name=test_name(args),
) )
, ,

Loading…
Cancel
Save