diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index e9ce31c..f8b5008 100644
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -117,9 +117,9 @@ def main():
     import argparse
     args = parse_arguments(argparse)
 
-    if args.algorithm == "ppo":
+    if args.algorithm == "PPO":
         ppo(args)
-    elif args.algorithm == "dqn":
+    elif args.algorithm == "DQN":
         dqn(args)
 
 
diff --git a/examples/shields/rl/12_minigridrl_tune.py b/examples/shields/rl/12_minigridrl_tune.py
new file mode 100644
index 0000000..57beb01
--- /dev/null
+++ b/examples/shields/rl/12_minigridrl_tune.py
@@ -0,0 +1,133 @@
+import gymnasium as gym
+import minigrid
+
+from ray import tune, air
+from ray.tune import register_env
+from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.dqn.dqn import DQNConfig
+from ray.tune.logger import pretty_print
+from ray.rllib.models import ModelCatalog
+
+
+from torch_action_mask_model import TorchActionMaskModel
+from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from shieldhandlers import MiniGridShieldHandler, create_shield_query
+from callbacks import MyCallbacks
+
+from torch.utils.tensorboard import SummaryWriter
+from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
+
+    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
+
+    env = gym.make(name)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
+    # env = minigrid.wrappers.ImgObsWrapper(env)
+    # env = ImgObsWrapper(env)
+    env = OneHotShieldingWrapper(env,
+                                config.vector_index if hasattr(config, "vector_index") else 0,
+                                framestack=framestack
+                                )
+
+
+    return env
+
+
+
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+
+    ModelCatalog.register_custom_model(
+        "shielding_model",
+        TorchActionMaskModel
+    )
+
+
+def ppo(args):
+    register_minigrid_shielding_env(args)
+
+    config = (PPOConfig()
+        .rollouts(num_rollout_workers=args.workers)
+        .resources(num_gpus=0)
+        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
+        .framework("torch")
+        .callbacks(MyCallbacks)
+        .rl_module(_enable_rl_module_api = False)
+        .debugging(logger_config={
+            "type": TBXLogger,
+            "logdir": create_log_dir(args)
+        })
+        .training(_enable_learner_api=False ,model={
+            "custom_model": "shielding_model"
+        }))
+
+    return config
+
+
+
+def dqn(args):
+    register_minigrid_shielding_env(args)
+
+
+    config = DQNConfig()
+    config = config.resources(num_gpus=0)
+    config = config.rollouts(num_rollout_workers=args.workers)
+    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
+    config = config.framework("torch")
+    config = config.callbacks(MyCallbacks)
+    config = config.rl_module(_enable_rl_module_api = False)
+    config = config.debugging(logger_config={
+        "type": TBXLogger,
+        "logdir": create_log_dir(args)
+    })
+    config = config.training(hiddens=[], dueling=False, model={
+        "custom_model": "shielding_model"
+    })
+
+    return config
+
+
+def main():
+    import argparse
+    args = parse_arguments(argparse)
+
+    if args.algorithm == "PPO":
+        config = ppo(args)
+    elif args.algorithm == "DQN":
+        config = dqn(args)
+
+    logdir = create_log_dir(args)
+
+    tuner = tune.Tuner(args.algorithm,
+                       tune_config=tune.TuneConfig(
+                           metric="episode_reward_mean",
+                           mode="max",
+                           num_samples=1,
+
+                       ),
+                       run_config=air.RunConfig(
+                           stop = {"episode_reward_mean": 94,
+                                   "timesteps_total": 12000,
+                                   "training_iteration": args.iterations},
+                           checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ),
+                           storage_path=F"{logdir}"
+                       ),
+                       param_space=config,
+    )
+
+    tuner.fit()
+
+
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py
index ade53d1..e7a163c 100644
--- a/examples/shields/rl/15_train_eval_tune.py
+++ b/examples/shields/rl/15_train_eval_tune.py
@@ -20,7 +20,7 @@ from callbacks import MyCallbacks
 
 def shielding_env_creater(config):
-    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0")
     framestack = config.get("framestack", 4)
     args = config.get("args", None)
     args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
@@ -89,7 +89,6 @@ def ppo(args):
                "training_iteration": args.iterations},
         checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ),
         storage_path=F"{logdir}"
-        #storage_path="../niceslogging/test"
     )
     , param_space=config,)
 
diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py
index 690855c..6fc37bd 100644
--- a/examples/shields/rl/helpers.py
+++ b/examples/shields/rl/helpers.py
@@ -88,7 +88,7 @@ def parse_arguments(argparse):
     parser.add_argument("--grid_to_prism_binary_path", default="./main")
     parser.add_argument("--grid_path", default="grid")
     parser.add_argument("--prism_path", default="grid")
-    parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
+    parser.add_argument("--algorithm", default="PPO", type=str.upper , choices=["PPO", "DQN"])
     parser.add_argument("--log_dir", default="../log_results/")
     parser.add_argument("--iterations", type=int, default=30 )
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"