diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index 187c0bd..720d11c 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -1,6 +1,7 @@
 import gymnasium as gym
 import minigrid
 
+import ray
 from ray.tune import register_env
 from ray import tune, air
 from ray.rllib.algorithms.ppo import PPOConfig
@@ -82,7 +83,6 @@ def ppo(args):
             metric="episode_reward_mean",
             mode="max",
             num_samples=1,
-        ),
         run_config=air.RunConfig(
             stop = {"episode_reward_mean": 94,
 
@@ -133,6 +133,7 @@ def ppo(args):
 
 
 def main():
+    ray.init(num_cpus=4)
     import argparse
     args = parse_arguments(argparse)
 
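
For context, a minimal sketch of how the changed pieces fit together once the patch is applied — assuming the Ray 2.x AIR-era API that the imports suggest; the environment name and PPO settings below are illustrative placeholders, not values from the file:

import ray
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig

ray.init(num_cpus=4)  # cap the local Ray worker pool, as main() now does

config = PPOConfig().environment(env="CartPole-v1")  # placeholder env

tuner = tune.Tuner(
    "PPO",
    param_space=config.to_dict(),
    tune_config=tune.TuneConfig(
        metric="episode_reward_mean",  # rank trials by mean episode return
        mode="max",
        num_samples=1,                 # a single trial, as in the patch
    ),
    run_config=air.RunConfig(
        stop={"episode_reward_mean": 94},  # stop once mean return reaches 94
    ),
)
results = tuner.fit()

Calling ray.init(num_cpus=4) explicitly keeps Tune from claiming every core on the machine; without it, Ray is initialized implicitly with all detected CPUs.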