| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -25,6 +25,7 @@ import numpy as np | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import ray | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.tune import register_env | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.algorithms.ppo import PPOConfig | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.algorithms.dqn.dqn import DQNConfig | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray import tune, air | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.algorithms.callbacks import DefaultCallbacks | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -37,7 +38,7 @@ from ray.rllib.utils.torch_utils import FLOAT_MIN | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from ray.rllib.models.preprocessors import get_preprocessor | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from MaskEnvironments import ParametricActionsMiniGridEnv | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from MaskModels import TorchActionMaskModel | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from Wrapper import OneHotWrapper, MiniGridEnvWrapper, ImgObsWrapper | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					from Wrapper import OneHotWrapper, MiniGridEnvWrapper | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					import matplotlib.pyplot as plt | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -62,7 +63,7 @@ class MyCallbacks(DefaultCallbacks): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         episode.user_data["count"] = episode.user_data["count"] + 1 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         env = base_env.get_sub_environments()[0] | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         print(env.env.env.printGrid()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					         #print(env.env.env.printGrid()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -83,6 +84,7 @@ def parse_arguments(argparse): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    parser.add_argument("--grid_path", default="Grid.txt") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    parser.add_argument("--prism_path", default="Grid.PRISM") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    parser.add_argument("--no_masking", default=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"]) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args = parser.parse_args() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -108,13 +110,13 @@ def env_creater_custom(config): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        framestack=framestack | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					                        ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    obs = env.observation_space.sample() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    obs2, infos = env.reset(seed=None, options={}) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # obs = env.observation_space.sample() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # obs2, infos = env.reset(seed=None, options={}) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    print(F"Obs is {obs} before reset. After reset: {obs2}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # print(F"Obs is {obs} before reset. After reset: {obs2}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    print(F"Created Custom Minigrid Environment is {env}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # print(F"Created Custom Minigrid Environment is {env}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return env | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -194,12 +196,16 @@ def create_environment(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return env | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args = parse_arguments(argparse) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def register_custom_minigrid_env(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env_name = "mini-grid" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_env(env_name, env_creater_custom) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ModelCatalog.register_custom_model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        "pa_model",  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        TorchActionMaskModel | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def create_shield_dict(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env = create_environment(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ray.init(num_cpus=3) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # print(env.pprint_grid()) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    # print(env.printGrid(init=False)) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				
					@ -215,19 +221,21 @@ def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    #     choices = shield.get_choice(state_id) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    #     print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    env_name = "mini-grid" | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_env(env_name, env_creater_custom) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ModelCatalog.register_custom_model( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        "pa_model",  | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        TorchActionMaskModel | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    return shield_dict | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def ppo(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ray.init(num_cpus=3) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_custom_minigrid_env() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    shield_dict = create_shield_dict(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = (PPOConfig() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .rollouts(num_rollout_workers=1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .resources(num_gpus=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .environment(env="mini-grid", env_config={"shield": shield_dict }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .framework("torch")        | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .experimental(_disable_preprocessor_api=False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .callbacks(MyCallbacks) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .rl_module(_enable_rl_module_api = False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        .training(_enable_learner_api=False ,model={ | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				
					@ -256,9 +264,47 @@ def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            checkpoint_dir = algo.save() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            print(f"Checkpoint saved in directory {checkpoint_dir}") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					             | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ray.shutdown() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def dqn(args): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = DQNConfig() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    register_custom_minigrid_env() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    shield_dict = create_shield_dict(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    replay_config = config.replay_buffer_config.update( | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        { | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "capacity": 60000, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "prioritized_replay_alpha": 0.5, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "prioritized_replay_beta": 0.5, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "prioritized_replay_eps": 3e-6, | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        } | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.training(replay_buffer_config=replay_config, model={     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "custom_model": "pa_model", | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					            "custom_model_config" : {"shield": shield_dict, "no_masking": args.no_masking} | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.resources(num_gpus=0) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.rollouts(num_rollout_workers=1) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.framework("torch") | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.callbacks(MyCallbacks) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.rl_module(_enable_rl_module_api = False) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    config = config.environment(env="mini-grid", env_config={"shield": shield_dict }) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					     | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					def main(): | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    args = parse_arguments(argparse) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    if args.algorithm == "ppo": | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        ppo(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    elif args.algorithm == "dqn": | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					        dqn(args) | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    ray.shutdown() | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					if __name__ == '__main__': | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				
					    main() |