|
@ -1,6 +1,7 @@ |
|
|
import gymnasium as gym |
|
|
import gymnasium as gym |
|
|
import minigrid |
|
|
import minigrid |
|
|
|
|
|
|
|
|
|
|
|
import ray |
|
|
from ray.tune import register_env |
|
|
from ray.tune import register_env |
|
|
from ray import tune, air |
|
|
from ray import tune, air |
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
@ -82,7 +83,6 @@ def ppo(args): |
|
|
metric="episode_reward_mean", |
|
|
metric="episode_reward_mean", |
|
|
mode="max", |
|
|
mode="max", |
|
|
num_samples=1, |
|
|
num_samples=1, |
|
|
|
|
|
|
|
|
), |
|
|
), |
|
|
run_config=air.RunConfig( |
|
|
run_config=air.RunConfig( |
|
|
stop = {"episode_reward_mean": 94, |
|
|
stop = {"episode_reward_mean": 94, |
|
@ -133,6 +133,7 @@ def ppo(args): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |
|
|
|
|
|
ray.init(num_cpus=4) |
|
|
import argparse |
|
|
import argparse |
|
|
args = parse_arguments(argparse) |
|
|
args = parse_arguments(argparse) |
|
|
|
|
|
|
|
|