From 41f94bf92e017459c3884e26a6b0bcf77da9f965 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Fri, 15 Sep 2023 09:51:33 +0200 Subject: [PATCH] added expname to grid path --- examples/shields/rl/15_train_eval_tune.py | 26 +++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index 88be821..ba8c8e2 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -10,7 +10,7 @@ from ray.tune.logger import UnifiedLogger from ray.rllib.models import ModelCatalog from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger from ray.rllib.algorithms.algorithm import Algorithm - +from ray.air import session from torch_action_mask_model import TorchActionMaskModel from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper @@ -25,11 +25,13 @@ def shielding_env_creater(config): name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") framestack = config.get("framestack", 4) args = config.get("args", None) - args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" - args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" - + args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" + args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism" shielding = config.get("shielding", False) - shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) + shield_creator = MiniGridShieldHandler(grid_file=args.grid_path, + grid_to_prism_path=args.grid_to_prism_binary_path, + prism_path=args.prism_path, + formula=args.formula) env = gym.make(name) env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) @@ -64,7 +66,10 @@ def ppo(args): .rollouts(num_rollout_workers=args.workers) .resources(num_gpus=0) .environment( env="mini-grid-shielding", - env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) + env_config={"name": args.env, + "args": args, + "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, + },) .framework("torch") .callbacks(MyCallbacks) .evaluation(evaluation_config={ @@ -72,7 +77,9 @@ def ppo(args): "evaluation_duration": 10, "evaluation_num_workers":1, "env": "mini-grid-shielding", - "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) + "env_config": {"name": args.env, + "args": args, + "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ "type": UnifiedLogger, @@ -87,7 +94,8 @@ def ppo(args): metric="episode_reward_mean", mode="max", num_samples=1, - trial_name_creator=trial_name_creator, + trial_name_creator=trial_name_creator, + ), run_config=air.RunConfig( stop = {"episode_reward_mean": 94, @@ -144,7 +152,7 @@ def ppo(args): def main(): - ray.init(num_cpus=4) + ray.init(num_cpus=3) import argparse args = parse_arguments(argparse)