|
@ -13,7 +13,7 @@ from ray.rllib.algorithms.algorithm import Algorithm |
|
|
|
|
|
|
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
|
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |
|
|
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |
|
|
from helpers import parse_arguments, create_log_dir, ShieldingConfig |
|
|
|
|
|
|
|
|
from helpers import parse_arguments, create_log_dir, ShieldingConfig, test_name |
|
|
from shieldhandlers import MiniGridShieldHandler, create_shield_query |
|
|
from shieldhandlers import MiniGridShieldHandler, create_shield_query |
|
|
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
@ -54,7 +54,7 @@ def register_minigrid_shielding_env(args): |
|
|
|
|
|
|
|
|
def ppo(args): |
|
|
def ppo(args): |
|
|
register_minigrid_shielding_env(args) |
|
|
register_minigrid_shielding_env(args) |
|
|
logdir = create_log_dir(args) |
|
|
|
|
|
|
|
|
logdir = args.log_dir |
|
|
|
|
|
|
|
|
config = (PPOConfig() |
|
|
config = (PPOConfig() |
|
|
.rollouts(num_rollout_workers=args.workers) |
|
|
.rollouts(num_rollout_workers=args.workers) |
|
@ -92,7 +92,9 @@ def ppo(args): |
|
|
checkpoint_score_attribute="episode_reward_mean", |
|
|
checkpoint_score_attribute="episode_reward_mean", |
|
|
), |
|
|
), |
|
|
|
|
|
|
|
|
storage_path=F"{logdir}" |
|
|
|
|
|
|
|
|
storage_path=F"{logdir}", |
|
|
|
|
|
name=test_name(args) |
|
|
|
|
|
|
|
|
) |
|
|
) |
|
|
, |
|
|
, |
|
|
param_space=config,) |
|
|
param_space=config,) |
|
@ -113,26 +115,26 @@ def ppo(args): |
|
|
algo = Algorithm.from_checkpoint(best_result.checkpoint) |
|
|
algo = Algorithm.from_checkpoint(best_result.checkpoint) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval_log_dir = F"{logdir}-eval" |
|
|
|
|
|
|
|
|
# eval_log_dir = F"{logdir}-eval" |
|
|
|
|
|
|
|
|
writer = SummaryWriter(log_dir=eval_log_dir) |
|
|
|
|
|
csv_logger = CSVLogger(config=config, logdir=eval_log_dir) |
|
|
|
|
|
|
|
|
# writer = SummaryWriter(log_dir=eval_log_dir) |
|
|
|
|
|
# csv_logger = CSVLogger(config=config, logdir=eval_log_dir) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(args.evaluations): |
|
|
|
|
|
eval_result = algo.evaluate() |
|
|
|
|
|
print(pretty_print(eval_result)) |
|
|
|
|
|
print(eval_result) |
|
|
|
|
|
# logger.on_result(eval_result) |
|
|
|
|
|
|
|
|
# for i in range(args.evaluations): |
|
|
|
|
|
# eval_result = algo.evaluate() |
|
|
|
|
|
# print(pretty_print(eval_result)) |
|
|
|
|
|
# print(eval_result) |
|
|
|
|
|
# # logger.on_result(eval_result) |
|
|
|
|
|
|
|
|
csv_logger.on_result(eval_result) |
|
|
|
|
|
|
|
|
# csv_logger.on_result(eval_result) |
|
|
|
|
|
|
|
|
evaluation = eval_result['evaluation'] |
|
|
|
|
|
epsiode_reward_mean = evaluation['episode_reward_mean'] |
|
|
|
|
|
episode_len_mean = evaluation['episode_len_mean'] |
|
|
|
|
|
print(epsiode_reward_mean) |
|
|
|
|
|
writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i) |
|
|
|
|
|
writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i) |
|
|
|
|
|
|
|
|
# evaluation = eval_result['evaluation'] |
|
|
|
|
|
# epsiode_reward_mean = evaluation['episode_reward_mean'] |
|
|
|
|
|
# episode_len_mean = evaluation['episode_len_mean'] |
|
|
|
|
|
# print(epsiode_reward_mean) |
|
|
|
|
|
# writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i) |
|
|
|
|
|
# writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |
|
|