
changed ray tune example

refactoring
Thomas Knoll, 1 year ago
commit 717c644aad
  1. examples/shields/rl/15_train_eval_tune.py (64 lines changed)
  2. examples/shields/rl/helpers.py (2 lines changed)

examples/shields/rl/15_train_eval_tune.py

@@ -6,6 +6,8 @@ from ray import tune, air
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.tune.logger import UnifiedLogger
 from ray.rllib.models import ModelCatalog
+from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger
+from ray.rllib.algorithms.algorithm import Algorithm
 from torch_action_mask_model import TorchActionMaskModel
@@ -13,6 +15,7 @@ from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
 from helpers import parse_arguments, create_log_dir, ShieldingConfig
 from shieldhandlers import MiniGridShieldHandler, create_shield_query
+from torch.utils.tensorboard import SummaryWriter
 from callbacks import MyCallbacks
@@ -24,10 +27,6 @@ def shielding_env_creater(config):
     args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
     shielding = config.get("shielding", False)
-    # if shielding:
-    #     assert(False)
     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
     env = gym.make(name)
@@ -54,6 +53,7 @@ def register_minigrid_shielding_env(args):
 def ppo(args):
     register_minigrid_shielding_env(args)
+    logdir = create_log_dir(args)
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
@@ -71,25 +71,65 @@ def ppo(args):
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": UnifiedLogger,
-            "logdir": create_log_dir(args)
+            "logdir": logdir
         })
         .training(_enable_learner_api=False ,model={
             "custom_model": "shielding_model"
         }))
     tuner = tune.Tuner("PPO",
+        tune_config=tune.TuneConfig(
+            metric="episode_reward_mean",
+            mode="max",
+            num_samples=1,
+        ),
         run_config=air.RunConfig(
-            stop = {"episode_reward_mean": 50},
-            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
-            storage_path=F"{create_log_dir(args)}-tuner"
-        ),
+            stop = {"episode_reward_mean": 94,
+                    "training_iteration": args.iterations},
+            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ),
+            storage_path=F"{logdir}"
+            #storage_path="../niceslogging/test"
+        )
+        ,
         param_space=config,)
-    tuner.fit()
+    results = tuner.fit()
+    best_result = results.get_best_result()
+    import pprint
+    metrics_to_print = [
+        "episode_reward_mean",
+        "episode_reward_max",
+        "episode_reward_min",
+        "episode_len_mean",
+    ]
+    pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print})
+    algo = Algorithm.from_checkpoint(best_result.checkpoint)
+    eval_log_dir = F"{logdir}-eval"
+    writer = SummaryWriter(log_dir=eval_log_dir)
+    csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
+    for i in range(args.iterations):
+        eval_result = algo.evaluate()
+        print(pretty_print(eval_result))
+        print(eval_result)
+        # logger.on_result(eval_result)
+        csv_logger.on_result(eval_result)
+        # print(epsiode_reward_mean)
+        # writer.add_scalar("evaluation/episode_reward", epsiode_reward_mean, i)
+        evaluation = eval_result['evaluation']
+        epsiode_reward_mean = evaluation['episode_reward_mean']
+        episode_len_mean = evaluation['episode_len_mean']
+        print(epsiode_reward_mean)
+        writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
+        writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
 def main():
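
Note: the change turns the script from a single tuner.fit() call into a train-then-evaluate pipeline: Tune selects the best trial, its checkpoint is restored as an Algorithm, and evaluation metrics are written to TensorBoard by hand. Below is a minimal, self-contained sketch of that pattern, assuming Ray 2.x (tune.Tuner / air.RunConfig, legacy non-RLModule stack); CartPole-v1 and the log directory name stand in for the MiniGrid shielding setup and are purely illustrative.

# Sketch of the train-then-evaluate pattern introduced above (assumptions noted in the lead-in).
from ray import air, tune
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from torch.utils.tensorboard import SummaryWriter

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(num_rollout_workers=1)
    # evaluate() below needs evaluation workers to be configured
    .evaluation(evaluation_num_workers=1, evaluation_duration=5)
)

tuner = tune.Tuner(
    "PPO",
    tune_config=tune.TuneConfig(metric="episode_reward_mean", mode="max", num_samples=1),
    run_config=air.RunConfig(
        stop={"training_iteration": 5},
        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
    ),
    param_space=config,
)
results = tuner.fit()

# Restore the best checkpoint and log separate evaluation rounds to TensorBoard.
best_result = results.get_best_result()
algo = Algorithm.from_checkpoint(best_result.checkpoint)
writer = SummaryWriter(log_dir="cartpole-eval")  # placeholder directory
for i in range(3):
    evaluation = algo.evaluate()["evaluation"]  # result layout as used in the commit above
    writer.add_scalar("evaluation/episode_reward_mean", evaluation["episode_reward_mean"], i)
    writer.add_scalar("evaluation/episode_len_mean", evaluation["episode_len_mean"], i)
writer.close()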

examples/shields/rl/helpers.py

@@ -39,7 +39,7 @@ def extract_keys(env):
     return keys
 def create_log_dir(args):
-    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}-iterations:{args.iterations}"
+    return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-iterations:{args.iterations}"
 def get_action_index_mapping(actions):
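
Note: dropping datetime.now() means repeated runs with the same settings now resolve to the same log directory, which 15_train_eval_tune.py also uses as the Tuner storage_path. A small illustration of the new naming scheme, with made-up argument values:

from types import SimpleNamespace

# Hypothetical argument values, purely to illustrate the new naming scheme.
args = SimpleNamespace(log_dir="../logs/", algorithm="PPO", shielding=True, iterations=30)

def create_log_dir(args):
    return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-iterations:{args.iterations}"

print(create_log_dir(args))  # ../logs/PPO-shielding:True-iterations:30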
