|
@ -6,6 +6,8 @@ from ray import tune, air |
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
|
from ray.rllib.algorithms.ppo import PPOConfig |
|
|
from ray.tune.logger import UnifiedLogger |
|
|
from ray.tune.logger import UnifiedLogger |
|
|
from ray.rllib.models import ModelCatalog |
|
|
from ray.rllib.models import ModelCatalog |
|
|
|
|
|
from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger |
|
|
|
|
|
from ray.rllib.algorithms.algorithm import Algorithm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
|
from torch_action_mask_model import TorchActionMaskModel |
|
@ -13,6 +15,7 @@ from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |
|
|
from helpers import parse_arguments, create_log_dir, ShieldingConfig |
|
|
from helpers import parse_arguments, create_log_dir, ShieldingConfig |
|
|
from shieldhandlers import MiniGridShieldHandler, create_shield_query |
|
|
from shieldhandlers import MiniGridShieldHandler, create_shield_query |
|
|
|
|
|
|
|
|
|
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from callbacks import MyCallbacks |
|
|
from callbacks import MyCallbacks |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -24,10 +27,6 @@ def shielding_env_creater(config): |
|
|
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" |
|
|
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" |
|
|
|
|
|
|
|
|
shielding = config.get("shielding", False) |
|
|
shielding = config.get("shielding", False) |
|
|
|
|
|
|
|
|
# if shielding: |
|
|
|
|
|
# assert(False) |
|
|
|
|
|
|
|
|
|
|
|
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |
|
|
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |
|
|
|
|
|
|
|
|
env = gym.make(name) |
|
|
env = gym.make(name) |
|
@ -54,6 +53,7 @@ def register_minigrid_shielding_env(args): |
|
|
|
|
|
|
|
|
def ppo(args): |
|
|
def ppo(args): |
|
|
register_minigrid_shielding_env(args) |
|
|
register_minigrid_shielding_env(args) |
|
|
|
|
|
logdir = create_log_dir(args) |
|
|
|
|
|
|
|
|
config = (PPOConfig() |
|
|
config = (PPOConfig() |
|
|
.rollouts(num_rollout_workers=args.workers) |
|
|
.rollouts(num_rollout_workers=args.workers) |
|
@ -71,25 +71,65 @@ def ppo(args): |
|
|
.rl_module(_enable_rl_module_api = False) |
|
|
.rl_module(_enable_rl_module_api = False) |
|
|
.debugging(logger_config={ |
|
|
.debugging(logger_config={ |
|
|
"type": UnifiedLogger, |
|
|
"type": UnifiedLogger, |
|
|
"logdir": create_log_dir(args) |
|
|
|
|
|
|
|
|
"logdir": logdir |
|
|
}) |
|
|
}) |
|
|
.training(_enable_learner_api=False ,model={ |
|
|
.training(_enable_learner_api=False ,model={ |
|
|
"custom_model": "shielding_model" |
|
|
"custom_model": "shielding_model" |
|
|
})) |
|
|
})) |
|
|
|
|
|
|
|
|
tuner = tune.Tuner("PPO", |
|
|
tuner = tune.Tuner("PPO", |
|
|
|
|
|
tune_config=tune.TuneConfig( |
|
|
|
|
|
metric="episode_reward_mean", |
|
|
|
|
|
mode="max", |
|
|
|
|
|
num_samples=1, |
|
|
|
|
|
|
|
|
|
|
|
), |
|
|
run_config=air.RunConfig( |
|
|
run_config=air.RunConfig( |
|
|
stop = {"episode_reward_mean": 50}, |
|
|
|
|
|
checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True), |
|
|
|
|
|
storage_path=F"{create_log_dir(args)}-tuner" |
|
|
|
|
|
), |
|
|
|
|
|
|
|
|
stop = {"episode_reward_mean": 94, |
|
|
|
|
|
"training_iteration": args.iterations}, |
|
|
|
|
|
checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), |
|
|
|
|
|
storage_path=F"{logdir}" |
|
|
|
|
|
#storage_path="../niceslogging/test" |
|
|
|
|
|
) |
|
|
|
|
|
, |
|
|
param_space=config,) |
|
|
param_space=config,) |
|
|
|
|
|
|
|
|
tuner.fit() |
|
|
|
|
|
|
|
|
results = tuner.fit() |
|
|
|
|
|
best_result = results.get_best_result() |
|
|
|
|
|
|
|
|
|
|
|
import pprint |
|
|
|
|
|
|
|
|
|
|
|
metrics_to_print = [ |
|
|
|
|
|
"episode_reward_mean", |
|
|
|
|
|
"episode_reward_max", |
|
|
|
|
|
"episode_reward_min", |
|
|
|
|
|
"episode_len_mean", |
|
|
|
|
|
] |
|
|
|
|
|
pprint.pprint({k: v for k, v in best_result.metrics.items() if k in metrics_to_print}) |
|
|
|
|
|
|
|
|
|
|
|
algo = Algorithm.from_checkpoint(best_result.checkpoint) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval_log_dir = F"{logdir}-eval" |
|
|
|
|
|
|
|
|
|
|
|
writer = SummaryWriter(log_dir=eval_log_dir) |
|
|
|
|
|
csv_logger = CSVLogger(config=config, logdir=eval_log_dir) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(args.iterations): |
|
|
|
|
|
eval_result = algo.evaluate() |
|
|
|
|
|
print(pretty_print(eval_result)) |
|
|
|
|
|
print(eval_result) |
|
|
|
|
|
# logger.on_result(eval_result) |
|
|
|
|
|
|
|
|
|
|
|
csv_logger.on_result(eval_result) |
|
|
|
|
|
|
|
|
# print(epsiode_reward_mean) |
|
|
|
|
|
# writer.add_scalar("evaluation/episode_reward", epsiode_reward_mean, i) |
|
|
|
|
|
|
|
|
evaluation = eval_result['evaluation'] |
|
|
|
|
|
epsiode_reward_mean = evaluation['episode_reward_mean'] |
|
|
|
|
|
episode_len_mean = evaluation['episode_len_mean'] |
|
|
|
|
|
print(epsiode_reward_mean) |
|
|
|
|
|
writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i) |
|
|
|
|
|
writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(): |
|
|
def main(): |
|
|