|  | @ -9,47 +9,50 @@ from minigrid.core.actions import Actions | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | import time |  |  | import time | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig |  |  |  | 
		
	
		
			
				|  |  | from sb3utils import MiniGridSbShieldingWrapper |  |  |  | 
		
	
		
			
				|  |  |  |  |  | from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig | 
		
	
		
			
				|  |  |  |  |  | from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | def mask_fn(env: gym.Env): |  |  | def mask_fn(env: gym.Env): | 
		
	
		
			
				|  |  |     return env.create_action_mask() |  |  |     return env.create_action_mask() | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | def main(): |  |  | def main(): | 
		
	
		
			
				|  |  |     import argparse |  |  |  | 
		
	
		
			
				|  |  |     args = parse_arguments(argparse) |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |     args.grid_path = F"{args.grid_path}.txt" |  |  |  | 
		
	
		
			
				|  |  |     args.prism_path = F"{args.prism_path}.prism" |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     args = parse_sb3_arguments() | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  |     formula = args.formula | 
		
	
		
			
				|  |  |  |  |  |     shield_value = args.shield_value | 
		
	
		
			
				|  |  |  |  |  |     shield_comparison = args.shield_comparison | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  |     shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) | 
		
	
		
			
				|  |  |     env = gym.make(args.env, render_mode="rgb_array") |  |  |     env = gym.make(args.env, render_mode="rgb_array") | 
		
	
		
			
				|  |  |     env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) | 
		
	
		
			
				|  |  |     env = ActionMasker(env, mask_fn) |  |  |     env = ActionMasker(env, mask_fn) | 
		
	
		
			
				|  |  |     model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) |  |  |     model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |     steps = args.steps |  |  |     steps = args.steps | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |     model.learn(steps) |  |  |     model.learn(steps) | 
		
	
		
			
				|  |  |   |  |  |  | 
		
	
		
			
				|  |  |   #W  mean_reward, std_reward = evaluate_policy(model, model.get_env()) |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  |     print("Learning done, hit enter") | 
		
	
		
			
				|  |  |  |  |  |     input("") | 
		
	
		
			
				|  |  |     vec_env = model.get_env() |  |  |     vec_env = model.get_env() | 
		
	
		
			
				|  |  |     obs = vec_env.reset() |  |  |     obs = vec_env.reset() | 
		
	
		
			
				|  |  |     terminated = truncated = False |  |  |     terminated = truncated = False | 
		
	
		
			
				|  |  |     while not terminated and not truncated: |  |  |     while not terminated and not truncated: | 
		
	
		
			
				|  |  |         action_masks = None |  |  |         action_masks = None | 
		
	
		
			
				|  |  |         action, _states = model.predict(obs, action_masks=action_masks) |  |  |         action, _states = model.predict(obs, action_masks=action_masks) | 
		
	
		
			
				|  |  |  |  |  |         print(action) | 
		
	
		
			
				|  |  |         obs, reward, terminated, truncated, info = env.step(action) |  |  |         obs, reward, terminated, truncated, info = env.step(action) | 
		
	
		
			
				|  |  |         # action, _states = model.predict(obs, deterministic=True) |  |  |         # action, _states = model.predict(obs, deterministic=True) | 
		
	
		
			
				|  |  |         # obs, rewards, dones, info = vec_env.step(action) |  |  |         # obs, rewards, dones, info = vec_env.step(action) | 
		
	
		
			
				|  |  |         vec_env.render("human") |  |  |         vec_env.render("human") | 
		
	
		
			
				|  |  |         time.sleep(0.2) |  |  |         time.sleep(0.2) | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |      |  |  |  | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | if __name__ == '__main__': |  |  | if __name__ == '__main__': | 
		
	
		
			
				|  |  |     main() |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     main() |