|  | @ -10,7 +10,7 @@ from ray.tune.logger import UnifiedLogger | 
		
	
		
			
				|  |  | from ray.rllib.models import ModelCatalog |  |  | from ray.rllib.models import ModelCatalog | 
		
	
		
			
				|  |  | from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger |  |  | from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger | 
		
	
		
			
				|  |  | from ray.rllib.algorithms.algorithm import Algorithm |  |  | from ray.rllib.algorithms.algorithm import Algorithm | 
		
	
		
			
				|  |  | 
 |  |  |  | 
		
	
		
			
				|  |  |  |  |  | from ray.air import session | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | from torch_action_mask_model import TorchActionMaskModel |  |  | from torch_action_mask_model import TorchActionMaskModel | 
		
	
		
			
				|  |  | from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |  |  | from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper | 
		
	
	
		
			
				|  | @ -25,11 +25,13 @@ def shielding_env_creater(config): | 
		
	
		
			
				|  |  |     name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") |  |  |     name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") | 
		
	
		
			
				|  |  |     framestack = config.get("framestack", 4) |  |  |     framestack = config.get("framestack", 4) | 
		
	
		
			
				|  |  |     args = config.get("args", None) |  |  |     args = config.get("args", None) | 
		
	
		
			
				|  |  |     args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" |  |  |  | 
		
	
		
			
				|  |  |     args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" | 
		
	
		
			
				|  |  |  |  |  |     args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"    | 
		
	
		
			
				|  |  |     shielding = config.get("shielding", False)    |  |  |     shielding = config.get("shielding", False)    | 
		
	
		
			
				|  |  |     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,  | 
		
	
		
			
				|  |  |  |  |  |                                            grid_to_prism_path=args.grid_to_prism_binary_path, | 
		
	
		
			
				|  |  |  |  |  |                                            prism_path=args.prism_path, | 
		
	
		
			
				|  |  |  |  |  |                                            formula=args.formula) | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
		
			
				|  |  |     env = gym.make(name) |  |  |     env = gym.make(name) | 
		
	
		
			
				|  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) |  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) | 
		
	
	
		
			
				|  | @ -64,7 +66,10 @@ def ppo(args): | 
		
	
		
			
				|  |  |         .rollouts(num_rollout_workers=args.workers) |  |  |         .rollouts(num_rollout_workers=args.workers) | 
		
	
		
			
				|  |  |         .resources(num_gpus=0) |  |  |         .resources(num_gpus=0) | 
		
	
		
			
				|  |  |         .environment( env="mini-grid-shielding", |  |  |         .environment( env="mini-grid-shielding", | 
		
	
		
			
				|  |  |                       env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |                       env_config={"name": args.env, | 
		
	
		
			
				|  |  |  |  |  |                                   "args": args,                                   | 
		
	
		
			
				|  |  |  |  |  |                                   "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, | 
		
	
		
			
				|  |  |  |  |  |                                   },) | 
		
	
		
			
				|  |  |         .framework("torch") |  |  |         .framework("torch") | 
		
	
		
			
				|  |  |         .callbacks(MyCallbacks) |  |  |         .callbacks(MyCallbacks) | 
		
	
		
			
				|  |  |         .evaluation(evaluation_config={  |  |  |         .evaluation(evaluation_config={  | 
		
	
	
		
			
				|  | @ -72,7 +77,9 @@ def ppo(args): | 
		
	
		
			
				|  |  |                                         "evaluation_duration": 10, |  |  |                                         "evaluation_duration": 10, | 
		
	
		
			
				|  |  |                                         "evaluation_num_workers":1, |  |  |                                         "evaluation_num_workers":1, | 
		
	
		
			
				|  |  |                                         "env": "mini-grid-shielding",  |  |  |                                         "env": "mini-grid-shielding",  | 
		
	
		
			
				|  |  |                                         "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})         |  |  |  | 
		
	
		
			
				|  |  |  |  |  |                                         "env_config": {"name": args.env,  | 
		
	
		
			
				|  |  |  |  |  |                                                        "args": args,  | 
		
	
		
			
				|  |  |  |  |  |                                                        "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})         | 
		
	
		
			
				|  |  |         .rl_module(_enable_rl_module_api = False) |  |  |         .rl_module(_enable_rl_module_api = False) | 
		
	
		
			
				|  |  |         .debugging(logger_config={ |  |  |         .debugging(logger_config={ | 
		
	
		
			
				|  |  |             "type": UnifiedLogger,  |  |  |             "type": UnifiedLogger,  | 
		
	
	
		
			
				|  | @ -87,7 +94,8 @@ def ppo(args): | 
		
	
		
			
				|  |  |                            metric="episode_reward_mean", |  |  |                            metric="episode_reward_mean", | 
		
	
		
			
				|  |  |                            mode="max", |  |  |                            mode="max", | 
		
	
		
			
				|  |  |                            num_samples=1, |  |  |                            num_samples=1, | 
		
	
		
			
				|  |  |                            trial_name_creator=trial_name_creator,                            |  |  |  | 
		
	
		
			
				|  |  |  |  |  |                            trial_name_creator=trial_name_creator, | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |                        ), |  |  |                        ), | 
		
	
		
			
				|  |  |                         run_config=air.RunConfig( |  |  |                         run_config=air.RunConfig( | 
		
	
		
			
				|  |  |                                 stop = {"episode_reward_mean": 94, |  |  |                                 stop = {"episode_reward_mean": 94, | 
		
	
	
		
			
				|  | @ -144,7 +152,7 @@ def ppo(args): | 
		
	
		
			
				|  |  |          |  |  |          | 
		
	
		
			
				|  |  |      |  |  |      | 
		
	
		
			
				|  |  | def main(): |  |  | def main(): | 
		
	
		
			
				|  |  |     ray.init(num_cpus=4) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     ray.init(num_cpus=3) | 
		
	
		
			
				|  |  |     import argparse |  |  |     import argparse | 
		
	
		
			
				|  |  |     args = parse_arguments(argparse) |  |  |     args = parse_arguments(argparse) | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
	
		
			
				|  | 
 |