4 changed files with 137 additions and 5 deletions
			
			
		- 
					4examples/shields/rl/11_minigridrl.py
 - 
					133examples/shields/rl/12_minigridrl_tune.py
 - 
					3examples/shields/rl/15_train_eval_tune.py
 - 
					2examples/shields/rl/helpers.py
 
@ -0,0 +1,133 @@ | 
				
			|||
import gymnasium as gym | 
				
			|||
import minigrid | 
				
			|||
 | 
				
			|||
from ray import tune, air | 
				
			|||
from ray.tune import register_env | 
				
			|||
from ray.rllib.algorithms.algorithm import Algorithm | 
				
			|||
from ray.rllib.algorithms.ppo import PPOConfig | 
				
			|||
from ray.rllib.algorithms.dqn.dqn import DQNConfig | 
				
			|||
from ray.tune.logger import pretty_print | 
				
			|||
from ray.rllib.models import ModelCatalog | 
				
			|||
 | 
				
			|||
 | 
				
			|||
from torch_action_mask_model import TorchActionMaskModel | 
				
			|||
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper | 
				
			|||
from helpers import parse_arguments, create_log_dir, ShieldingConfig | 
				
			|||
from shieldhandlers import MiniGridShieldHandler, create_shield_query | 
				
			|||
from callbacks import MyCallbacks | 
				
			|||
 | 
				
			|||
from torch.utils.tensorboard import SummaryWriter | 
				
			|||
from ray.tune.logger import TBXLogger, UnifiedLogger, CSVLogger | 
				
			|||
 | 
				
			|||
def shielding_env_creater(config): | 
				
			|||
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | 
				
			|||
    framestack = config.get("framestack", 4) | 
				
			|||
    args = config.get("args", None) | 
				
			|||
    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" | 
				
			|||
    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" | 
				
			|||
     | 
				
			|||
    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) | 
				
			|||
     | 
				
			|||
    env = gym.make(name) | 
				
			|||
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) | 
				
			|||
    # env = minigrid.wrappers.ImgObsWrapper(env) | 
				
			|||
    # env = ImgObsWrapper(env) | 
				
			|||
    env = OneHotShieldingWrapper(env, | 
				
			|||
                        config.vector_index if hasattr(config, "vector_index") else 0, | 
				
			|||
                        framestack=framestack | 
				
			|||
                        ) | 
				
			|||
     | 
				
			|||
     | 
				
			|||
    return env | 
				
			|||
 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def register_minigrid_shielding_env(args): | 
				
			|||
    env_name = "mini-grid-shielding" | 
				
			|||
    register_env(env_name, shielding_env_creater) | 
				
			|||
 | 
				
			|||
    ModelCatalog.register_custom_model( | 
				
			|||
        "shielding_model",  | 
				
			|||
        TorchActionMaskModel | 
				
			|||
    ) | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def ppo(args): | 
				
			|||
    register_minigrid_shielding_env(args) | 
				
			|||
     | 
				
			|||
    config = (PPOConfig() | 
				
			|||
        .rollouts(num_rollout_workers=args.workers) | 
				
			|||
        .resources(num_gpus=0) | 
				
			|||
        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) | 
				
			|||
        .framework("torch") | 
				
			|||
        .callbacks(MyCallbacks) | 
				
			|||
        .rl_module(_enable_rl_module_api = False) | 
				
			|||
        .debugging(logger_config={ | 
				
			|||
            "type": TBXLogger,  | 
				
			|||
            "logdir": create_log_dir(args) | 
				
			|||
        }) | 
				
			|||
        .training(_enable_learner_api=False ,model={ | 
				
			|||
            "custom_model": "shielding_model" | 
				
			|||
        })) | 
				
			|||
     | 
				
			|||
    return config | 
				
			|||
     | 
				
			|||
             | 
				
			|||
 | 
				
			|||
def dqn(args): | 
				
			|||
    register_minigrid_shielding_env(args) | 
				
			|||
 | 
				
			|||
     | 
				
			|||
    config = DQNConfig() | 
				
			|||
    config = config.resources(num_gpus=0) | 
				
			|||
    config = config.rollouts(num_rollout_workers=args.workers) | 
				
			|||
    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args }) | 
				
			|||
    config = config.framework("torch") | 
				
			|||
    config = config.callbacks(MyCallbacks) | 
				
			|||
    config = config.rl_module(_enable_rl_module_api = False) | 
				
			|||
    config = config.debugging(logger_config={ | 
				
			|||
            "type": TBXLogger,  | 
				
			|||
            "logdir": create_log_dir(args) | 
				
			|||
        }) | 
				
			|||
    config = config.training(hiddens=[], dueling=False, model={     | 
				
			|||
            "custom_model": "shielding_model" | 
				
			|||
    }) | 
				
			|||
     | 
				
			|||
    return config | 
				
			|||
             | 
				
			|||
 | 
				
			|||
def main(): | 
				
			|||
    import argparse | 
				
			|||
    args = parse_arguments(argparse) | 
				
			|||
 | 
				
			|||
    if args.algorithm == "PPO": | 
				
			|||
        config = ppo(args) | 
				
			|||
    elif args.algorithm == "DQN": | 
				
			|||
        config = dqn(args) | 
				
			|||
         | 
				
			|||
    logdir = create_log_dir(args) | 
				
			|||
         | 
				
			|||
    tuner = tune.Tuner(args.algorithm, | 
				
			|||
                        tune_config=tune.TuneConfig( | 
				
			|||
                            metric="episode_reward_mean", | 
				
			|||
                            mode="max", | 
				
			|||
                            num_samples=1, | 
				
			|||
                             | 
				
			|||
                        ), | 
				
			|||
                        run_config=air.RunConfig( | 
				
			|||
                                stop = {"episode_reward_mean": 94, | 
				
			|||
                                        "timesteps_total": 12000, | 
				
			|||
                                        "training_iteration": args.iterations},  | 
				
			|||
                                checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2 ), | 
				
			|||
                                storage_path=F"{logdir}" | 
				
			|||
                        ), | 
				
			|||
                        param_space=config, | 
				
			|||
                    ) | 
				
			|||
 | 
				
			|||
    tuner.fit() | 
				
			|||
  | 
				
			|||
    | 
				
			|||
 | 
				
			|||
 | 
				
			|||
if __name__ == '__main__': | 
				
			|||
    main() | 
				
			|||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue