|
|
@ -3,6 +3,7 @@ import minigrid |
|
|
|
|
|
|
|
import ray |
|
|
|
from ray.tune import register_env |
|
|
|
from ray.tune.experiment.trial import Trial |
|
|
|
from ray import tune, air |
|
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
|
|
from ray.tune.logger import UnifiedLogger |
|
|
@ -50,6 +51,9 @@ def register_minigrid_shielding_env(args): |
|
|
|
"shielding_model", |
|
|
|
TorchActionMaskModel |
|
|
|
) |
|
|
|
|
|
|
|
def trial_name_creator(trial : Trial): |
|
|
|
return "trial" |
|
|
|
|
|
|
|
|
|
|
|
def ppo(args): |
|
|
@ -83,6 +87,7 @@ def ppo(args): |
|
|
|
metric="episode_reward_mean", |
|
|
|
mode="max", |
|
|
|
num_samples=1, |
|
|
|
trial_name_creator=trial_name_creator, |
|
|
|
), |
|
|
|
run_config=air.RunConfig( |
|
|
|
stop = {"episode_reward_mean": 94, |
|
|
@ -93,7 +98,8 @@ def ppo(args): |
|
|
|
), |
|
|
|
|
|
|
|
storage_path=F"{logdir}", |
|
|
|
name=test_name(args) |
|
|
|
name=test_name(args), |
|
|
|
|
|
|
|
|
|
|
|
) |
|
|
|
, |