Browse Source

changed logdir handling

refactoring
Thomas Knoll 1 year ago
parent
commit
442fff1344
  1. 38
      examples/shields/rl/15_train_eval_tune.py
  2. 4
      examples/shields/rl/helpers.py

38
examples/shields/rl/15_train_eval_tune.py

@ -13,7 +13,7 @@ from ray.rllib.algorithms.algorithm import Algorithm
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name
from shieldhandlers import MiniGridShieldHandler, create_shield_query from shieldhandlers import MiniGridShieldHandler, create_shield_query
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
@ -54,7 +54,7 @@ def register_minigrid_shielding_env(args):
def ppo(args): def ppo(args):
register_minigrid_shielding_env(args) register_minigrid_shielding_env(args)
logdir = create_log_dir(args)
logdir = args.log_dir
config = (PPOConfig() config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers) .rollouts(num_rollout_workers=args.workers)
@ -92,7 +92,9 @@ def ppo(args):
checkpoint_score_attribute="episode_reward_mean", checkpoint_score_attribute="episode_reward_mean",
), ),
storage_path=F"{logdir}"
storage_path=F"{logdir}",
name=test_name(args)
) )
, ,
param_space=config,) param_space=config,)
@ -113,26 +115,26 @@ def ppo(args):
algo = Algorithm.from_checkpoint(best_result.checkpoint) algo = Algorithm.from_checkpoint(best_result.checkpoint)
eval_log_dir = F"{logdir}-eval"
# eval_log_dir = F"{logdir}-eval"
writer = SummaryWriter(log_dir=eval_log_dir)
csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
# writer = SummaryWriter(log_dir=eval_log_dir)
# csv_logger = CSVLogger(config=config, logdir=eval_log_dir)
for i in range(args.evaluations):
eval_result = algo.evaluate()
print(pretty_print(eval_result))
print(eval_result)
# logger.on_result(eval_result)
# for i in range(args.evaluations):
# eval_result = algo.evaluate()
# print(pretty_print(eval_result))
# print(eval_result)
# # logger.on_result(eval_result)
csv_logger.on_result(eval_result)
# csv_logger.on_result(eval_result)
evaluation = eval_result['evaluation']
epsiode_reward_mean = evaluation['episode_reward_mean']
episode_len_mean = evaluation['episode_len_mean']
print(epsiode_reward_mean)
writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
# evaluation = eval_result['evaluation']
# epsiode_reward_mean = evaluation['episode_reward_mean']
# episode_len_mean = evaluation['episode_len_mean']
# print(epsiode_reward_mean)
# writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i)
# writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
def main(): def main():

4
examples/shields/rl/helpers.py

@ -39,8 +39,10 @@ def extract_keys(env):
return keys return keys
def create_log_dir(args): def create_log_dir(args):
return F"{args.log_dir}{args.algorithm}-shielding:{args.shielding}-evaluations:{args.evaluations}-steps:{args.steps}-env:{args.env}"
return F"{args.log_dir}sh:{args.shielding}-env:{args.env}"
def test_name(args):
return F"sh:{args.shielding}-env:{args.env}"
def get_action_index_mapping(actions): def get_action_index_mapping(actions):
for action_str in actions: for action_str in actions:

Loading…
Cancel
Save