From 717c644aada72e98346b7e02a8f20d1c5f7425d7 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Fri, 8 Sep 2023 14:49:20 +0200 Subject: [PATCH] changed ray tune example --- examples/shields/rl/15_train_eval_tune.py | 72 ++++++++++++++++++----- examples/shields/rl/helpers.py | 2 +- 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py index aa64bc2..ade53d1 100644 --- a/examples/shields/rl/15_train_eval_tune.py +++ b/examples/shields/rl/15_train_eval_tune.py @@ -6,6 +6,8 @@ from ray import tune, air from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.rllib.models import ModelCatalog +from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger +from ray.rllib.algorithms.algorithm import Algorithm from torch_action_mask_model import TorchActionMaskModel @@ -13,6 +15,7 @@ from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from helpers import parse_arguments, create_log_dir, ShieldingConfig from shieldhandlers import MiniGridShieldHandler, create_shield_query +from torch.utils.tensorboard import SummaryWriter from callbacks import MyCallbacks @@ -23,11 +26,7 @@ def shielding_env_creater(config): args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" - shielding = config.get("shielding", False) - - # if shielding: - # assert(False) - + shielding = config.get("shielding", False) shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) env = gym.make(name) @@ -54,6 +53,7 @@ def register_minigrid_shielding_env(args): def ppo(args): register_minigrid_shielding_env(args) + logdir = create_log_dir(args) config = (PPOConfig() .rollouts(num_rollout_workers=args.workers) @@ -71,26 +71,66 @@ def ppo(args): .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ "type": UnifiedLogger, - "logdir": create_log_dir(args) + "logdir": logdir }) .training(_enable_learner_api=False ,model={ "custom_model": "shielding_model" })) tuner = tune.Tuner("PPO", + tune_config=tune.TuneConfig( + metric="episode_reward_mean", + mode="max", + num_samples=1, + + ), run_config=air.RunConfig( - stop = {"episode_reward_mean": 50}, - checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True), - storage_path=F"{create_log_dir(args)}-tuner" - ), + stop = {"episode_reward_mean": 94, + "training_iteration": args.iterations}, + checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), + storage_path=F"{logdir}" + #storage_path="../niceslogging/test" + ) + , param_space=config,) - tuner.fit() - - - # print(epsiode_reward_mean) - # writer.add_scalar("evaluation/episode_reward", epsiode_reward_mean, i) - + results = tuner.fit() + best_result = results.get_best_result() + + import pprint + + metrics_to_print = [ + "episode_reward_mean", + "episode_reward_max", + "episode_reward_min", + "episode_len_mean", +] + pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print}) + + algo = Algorithm.from_checkpoint(best_result.checkpoint) + + + eval_log_dir = F"{logdir}-eval" + + writer = SummaryWriter(log_dir=eval_log_dir) + csv_logger = CSVLogger(config=config, logdir=eval_log_dir) + + + for i in range(args.iterations): + eval_result = algo.evaluate() + print(pretty_print(eval_result)) + print(eval_result) + # logger.on_result(eval_result) + + csv_logger.on_result(eval_result) + + evaluation = eval_result['evaluation'] + epsiode_reward_mean = evaluation['episode_reward_mean'] + episode_len_mean = evaluation['episode_len_mean'] + print(epsiode_reward_mean) + writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i) + writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i) + def main(): import argparse diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index f906abf..690855c 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -39,7 +39,7 @@ def extract_keys(env): return keys def create_log_dir(args): - return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}-iterations:{args.iterations}" + return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-iterations:{args.iterations}" def get_action_index_mapping(actions):