4 changed files with 137 additions and 5 deletions

- examples/shields/rl/11_minigridrl.py (4 changed lines)
- examples/shields/rl/12_minigridrl_tune.py (133 changed lines)
- examples/shields/rl/15_train_eval_tune.py (3 changed lines)
- examples/shields/rl/helpers.py (2 changed lines)
examples/shields/rl/12_minigridrl_tune.py (new file)
@@ -0,0 +1,133 @@
import gymnasium as gym
import minigrid  # importing minigrid registers the MiniGrid-* environments with gymnasium

from ray import tune, air
from ray.tune import register_env
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog

from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks

from torch.utils.tensorboard import SummaryWriter
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger

def shielding_env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)
    args = config.get("args", None)
    # Give every rollout worker its own grid/PRISM file so that shield
    # construction in parallel workers does not interfere.
    args.grid_path = f"{args.grid_path}_{config.worker_index}.txt"
    args.prism_path = f"{args.prism_path}_{config.worker_index}.prism"

    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)

    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
    # env = minigrid.wrappers.ImgObsWrapper(env)
    # env = ImgObsWrapper(env)
    env = OneHotShieldingWrapper(env,
                                 config.vector_index if hasattr(config, "vector_index") else 0,
                                 framestack=framestack)

    return env


def register_minigrid_shielding_env(args):
    # Register the shielded environment and the action-masking model under
    # fixed names so the algorithm configs below can reference them.
    env_name = "mini-grid-shielding"
    register_env(env_name, shielding_env_creater)

    ModelCatalog.register_custom_model(
        "shielding_model",
        TorchActionMaskModel
    )


def ppo(args):
    register_minigrid_shielding_env(args)

    config = (PPOConfig()
        .rollouts(num_rollout_workers=args.workers)
        .resources(num_gpus=0)
        .environment(env="mini-grid-shielding",
                     env_config={"name": args.env,
                                 "args": args,
                                 "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
        .framework("torch")
        .callbacks(MyCallbacks)
        # Stay on the old ModelV2/learner stack so the custom action-mask model is used.
        .rl_module(_enable_rl_module_api=False)
        .debugging(logger_config={
            "type": TBXLogger,
            "logdir": create_log_dir(args)
        })
        .training(_enable_learner_api=False, model={
            "custom_model": "shielding_model"
        }))

    return config


def dqn(args):
    register_minigrid_shielding_env(args)

    config = DQNConfig()
    config = config.resources(num_gpus=0)
    config = config.rollouts(num_rollout_workers=args.workers)
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args})
    config = config.framework("torch")
    config = config.callbacks(MyCallbacks)
    config = config.rl_module(_enable_rl_module_api=False)
    config = config.debugging(logger_config={
        "type": TBXLogger,
        "logdir": create_log_dir(args)
    })
    # No dueling head and no extra post-processing hiddens, otherwise the
    # action mask produced by the custom model would be distorted.
    config = config.training(hiddens=[], dueling=False, model={
        "custom_model": "shielding_model"
    })

    return config


def main():
    import argparse
    args = parse_arguments(argparse)

    if args.algorithm == "PPO":
        config = ppo(args)
    elif args.algorithm == "DQN":
        config = dqn(args)
    else:
        raise ValueError(f"Unsupported algorithm: {args.algorithm}")

    logdir = create_log_dir(args)

    tuner = tune.Tuner(args.algorithm,
                       tune_config=tune.TuneConfig(
                           metric="episode_reward_mean",
                           mode="max",
                           num_samples=1,
                       ),
                       run_config=air.RunConfig(
                           stop={"episode_reward_mean": 94,
                                 "timesteps_total": 12000,
                                 "training_iteration": args.iterations},
                           checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
                           storage_path=f"{logdir}"
                       ),
                       param_space=config,
                       )

    tuner.fit()


if __name__ == '__main__':
    main()
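
A minimal sketch, not part of this change, of how the outcome of the tuning run could be inspected afterwards, assuming the trials report episode_reward_mean as in the stop criteria above:

    # Sketch: tune.Tuner.fit() returns a ray.tune.ResultGrid.
    result_grid = tuner.fit()
    best = result_grid.get_best_result(metric="episode_reward_mean", mode="max")
    print(best.config)      # hyperparameters of the best trial
    print(best.checkpoint)  # last checkpoint, restorable via Algorithm.from_checkpoint(...)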