You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							72 lines
						
					
					
						
							2.3 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							72 lines
						
					
					
						
							2.3 KiB
						
					
					
				| from sb3_contrib import MaskablePPO | |
| from sb3_contrib.common.maskable.evaluation import evaluate_policy | |
| from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy | |
| from sb3_contrib.common.wrappers import ActionMasker | |
| from stable_baselines3.common.callbacks import BaseCallback | |
| 
 | |
| import gymnasium as gym | |
| 
 | |
| from minigrid.core.actions import Actions | |
| 
 | |
| import time | |
| 
 | |
| from helpers import parse_arguments, create_log_dir, ShieldingConfig | |
| from shieldhandlers import MiniGridShieldHandler, create_shield_query | |
| from wrappers import MiniGridSbShieldingWrapper | |
| 
 | |
class CustomCallback(BaseCallback):
    """Callback that prints the environment's grid whenever ``_on_step`` fires.

    Args:
        verbose: Verbosity level forwarded to ``BaseCallback``.
        env: Object exposing a ``printGrid()`` method. When left as ``None``
            the callback prints nothing instead of failing.
    """

    def __init__(self, verbose: int = 0, env=None):
        super().__init__(verbose)
        self.env = env

    def _on_step(self) -> bool:
        """Print the grid (if an environment was supplied).

        Returns:
            Whatever the base class ``_on_step`` returns.
        """
        # ``env`` defaults to None; guard so the default does not raise
        # AttributeError on the first step.
        if self.env is not None:
            print(self.env.printGrid())
        return super()._on_step()
| 
 | |
| 
 | |
| 
 | |
| 
 | |
def mask_fn(env: gym.Env):
    """Return the action mask of *env* by delegating to ``env.create_action_mask()``.

    Passed to ``ActionMasker`` in ``main`` as the mask-producing function.
    """
    action_mask = env.create_action_mask()
    return action_mask
|      | |
| 
 | |
| 
 | |
def main():
    """Train a MaskablePPO agent on a shielded MiniGrid environment, then replay it.

    Steps, in order:
      1. Parse command-line arguments and append the ``.txt`` / ``.prism``
         suffixes to the grid and PRISM paths.
      2. Build the shield handler and wrap the gym environment with the
         shielding wrapper and ``ActionMasker``.
      3. Train for ``args.steps`` steps, printing the grid via ``CustomCallback``.
      4. Roll out one episode with the trained model, rendering each step
         with a 0.2 s pause.
    """
    import argparse

    args = parse_arguments(argparse)

    # The CLI supplies the paths without extension; add them here.
    args.grid_path = f"{args.grid_path}.txt"
    args.prism_path = f"{args.prism_path}.prism"

    shield_handler = MiniGridShieldHandler(
        args.grid_path,
        args.grid_to_prism_binary_path,
        args.prism_path,
        args.formula,
    )

    base_env = gym.make(args.env, render_mode="rgb_array")
    shielded_env = MiniGridSbShieldingWrapper(
        base_env,
        shield_creator=shield_handler,
        shield_query_creator=create_shield_query,
        # Actions are only masked when full shielding is requested.
        mask_actions=args.shielding == ShieldingConfig.Full,
    )
    env = ActionMasker(shielded_env, mask_fn)

    grid_printer = CustomCallback(1, env)
    log_dir = create_log_dir(args)
    model = MaskablePPO(
        MaskableActorCriticPolicy,
        env,
        gamma=0.4,
        verbose=1,
        tensorboard_log=log_dir,
    )

    total_steps = args.steps
    model.learn(total_steps, callback=grid_printer)

    # Quantitative evaluation is currently disabled:
    #   mean_reward, std_reward = evaluate_policy(model, model.get_env())

    # NOTE(review): the rollout resets through the model's vectorized env but
    # steps the un-vectorized ``env`` directly, and passes no action masks to
    # ``predict``. Kept as in the original -- confirm this mix is intended.
    vec_env = model.get_env()
    obs = vec_env.reset()
    terminated = False
    truncated = False
    while not (terminated or truncated):
        action, _ = model.predict(obs, action_masks=None)
        obs, _reward, terminated, truncated, _info = env.step(action)
        vec_env.render("human")
        time.sleep(0.2)
|      | |
|      | |
| 
 | |
# Run the training/replay pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == '__main__':
    main()