| 
					
					
						
							
						
					
					
				 | 
				@ -10,7 +10,7 @@ from ray.tune.logger import UnifiedLogger | 
			
		
		
	
		
			
				 | 
				 | 
				from ray.rllib.models import ModelCatalog | 
				 | 
				 | 
				from ray.rllib.models import ModelCatalog | 
			
		
		
	
		
			
				 | 
				 | 
				from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger | 
				 | 
				 | 
				from ray.tune.logger import pretty_print, UnifiedLogger, CSVLogger | 
			
		
		
	
		
			
				 | 
				 | 
				from ray.rllib.algorithms.algorithm import Algorithm | 
				 | 
				 | 
				from ray.rllib.algorithms.algorithm import Algorithm | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				from ray.air import session | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				from torch_action_mask_model import TorchActionMaskModel | 
				 | 
				 | 
				from torch_action_mask_model import TorchActionMaskModel | 
			
		
		
	
		
			
				 | 
				 | 
				from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper | 
				 | 
				 | 
				from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -25,11 +25,13 @@ def shielding_env_creater(config): | 
			
		
		
	
		
			
				 | 
				 | 
				    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") | 
				 | 
				 | 
				    name = config.get("name", "MiniGrid-LavaCrossingS9N3-v0") | 
			
		
		
	
		
			
				 | 
				 | 
				    framestack = config.get("framestack", 4) | 
				 | 
				 | 
				    framestack = config.get("framestack", 4) | 
			
		
		
	
		
			
				 | 
				 | 
				    args = config.get("args", None) | 
				 | 
				 | 
				    args = config.get("args", None) | 
			
		
		
	
		
			
				 | 
				 | 
				    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				     | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    args.grid_path = F"{args.expname}_{args.grid_path}_{config.worker_index}.txt" | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    args.prism_path = F"{args.expname}_{args.prism_path}_{config.worker_index}.prism"    | 
			
		
		
	
		
			
				 | 
				 | 
				    shielding = config.get("shielding", False)    | 
				 | 
				 | 
				    shielding = config.get("shielding", False)    | 
			
		
		
	
		
			
				 | 
				 | 
				    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    shield_creator = MiniGridShieldHandler(grid_file=args.grid_path,  | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                           grid_to_prism_path=args.grid_to_prism_binary_path, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                           prism_path=args.prism_path, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                           formula=args.formula) | 
			
		
		
	
		
			
				 | 
				 | 
				     | 
				 | 
				 | 
				     | 
			
		
		
	
		
			
				 | 
				 | 
				    env = gym.make(name) | 
				 | 
				 | 
				    env = gym.make(name) | 
			
		
		
	
		
			
				 | 
				 | 
				    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) | 
				 | 
				 | 
				    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
						
							
						
					
					
				 | 
				@ -64,7 +66,10 @@ def ppo(args): | 
			
		
		
	
		
			
				 | 
				 | 
				        .rollouts(num_rollout_workers=args.workers) | 
				 | 
				 | 
				        .rollouts(num_rollout_workers=args.workers) | 
			
		
		
	
		
			
				 | 
				 | 
				        .resources(num_gpus=0) | 
				 | 
				 | 
				        .resources(num_gpus=0) | 
			
		
		
	
		
			
				 | 
				 | 
				        .environment( env="mini-grid-shielding", | 
				 | 
				 | 
				        .environment( env="mini-grid-shielding", | 
			
		
		
	
		
			
				 | 
				 | 
				                      env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                      env_config={"name": args.env, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                  "args": args,                                   | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                  "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                  },) | 
			
		
		
	
		
			
				 | 
				 | 
				        .framework("torch") | 
				 | 
				 | 
				        .framework("torch") | 
			
		
		
	
		
			
				 | 
				 | 
				        .callbacks(MyCallbacks) | 
				 | 
				 | 
				        .callbacks(MyCallbacks) | 
			
		
		
	
		
			
				 | 
				 | 
				        .evaluation(evaluation_config={  | 
				 | 
				 | 
				        .evaluation(evaluation_config={  | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -72,7 +77,9 @@ def ppo(args): | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "evaluation_duration": 10, | 
				 | 
				 | 
				                                        "evaluation_duration": 10, | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "evaluation_num_workers":1, | 
				 | 
				 | 
				                                        "evaluation_num_workers":1, | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "env": "mini-grid-shielding",  | 
				 | 
				 | 
				                                        "env": "mini-grid-shielding",  | 
			
		
		
	
		
			
				 | 
				 | 
				                                        "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})         | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                        "env_config": {"name": args.env,  | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                                       "args": args,  | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                                                       "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})         | 
			
		
		
	
		
			
				 | 
				 | 
				        .rl_module(_enable_rl_module_api = False) | 
				 | 
				 | 
				        .rl_module(_enable_rl_module_api = False) | 
			
		
		
	
		
			
				 | 
				 | 
				        .debugging(logger_config={ | 
				 | 
				 | 
				        .debugging(logger_config={ | 
			
		
		
	
		
			
				 | 
				 | 
				            "type": UnifiedLogger,  | 
				 | 
				 | 
				            "type": UnifiedLogger,  | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -87,7 +94,8 @@ def ppo(args): | 
			
		
		
	
		
			
				 | 
				 | 
				                           metric="episode_reward_mean", | 
				 | 
				 | 
				                           metric="episode_reward_mean", | 
			
		
		
	
		
			
				 | 
				 | 
				                           mode="max", | 
				 | 
				 | 
				                           mode="max", | 
			
		
		
	
		
			
				 | 
				 | 
				                           num_samples=1, | 
				 | 
				 | 
				                           num_samples=1, | 
			
		
		
	
		
			
				 | 
				 | 
				                           trial_name_creator=trial_name_creator,                            | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                           trial_name_creator=trial_name_creator, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				                       ), | 
				 | 
				 | 
				                       ), | 
			
		
		
	
		
			
				 | 
				 | 
				                        run_config=air.RunConfig( | 
				 | 
				 | 
				                        run_config=air.RunConfig( | 
			
		
		
	
		
			
				 | 
				 | 
				                                stop = {"episode_reward_mean": 94, | 
				 | 
				 | 
				                                stop = {"episode_reward_mean": 94, | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
						
							
						
					
					
				 | 
				@ -144,7 +152,7 @@ def ppo(args): | 
			
		
		
	
		
			
				 | 
				 | 
				         | 
				 | 
				 | 
				         | 
			
		
		
	
		
			
				 | 
				 | 
				     | 
				 | 
				 | 
				     | 
			
		
		
	
		
			
				 | 
				 | 
				def main(): | 
				 | 
				 | 
				def main(): | 
			
		
		
	
		
			
				 | 
				 | 
				    ray.init(num_cpus=4) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    ray.init(num_cpus=3) | 
			
		
		
	
		
			
				 | 
				 | 
				    import argparse | 
				 | 
				 | 
				    import argparse | 
			
		
		
	
		
			
				 | 
				 | 
				    args = parse_arguments(argparse) | 
				 | 
				 | 
				    args = parse_arguments(argparse) | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |