|  |  | @ -12,13 +12,14 @@ from ray.rllib.models import ModelCatalog | 
			
		
	
		
			
				
					|  |  |  | from ray.rllib.algorithms.algorithm import Algorithm | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from torch_action_mask_model import TorchActionMaskModel | 
			
		
	
		
			
				
					|  |  |  | from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper | 
			
		
	
		
			
				
					|  |  |  | from helpers import parse_arguments, create_log_dir, ShieldingConfig | 
			
		
	
		
			
				
					|  |  |  | from shieldhandlers import MiniGridShieldHandler, create_shield_query | 
			
		
	
		
			
				
					|  |  |  | from callbacks import MyCallbacks | 
			
		
	
		
			
				
					|  |  |  | from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper | 
			
		
	
		
			
				
					|  |  |  | from utils import parse_arguments, create_log_dir, ShieldingConfig | 
			
		
	
		
			
				
					|  |  |  | from utils import MiniGridShieldHandler, create_shield_query | 
			
		
	
		
			
				
					|  |  |  | from callbacks import CustomCallback | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | from ray.tune.logger import TBXLogger    | 
			
		
	
		
			
				
					|  |  |  | import imageio | 
			
		
	
		
			
				
					|  |  |  | import os | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | import matplotlib.pyplot as plt | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -32,8 +33,8 @@ def shielding_env_creater(config): | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     env = gym.make(name) | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) | 
			
		
	
		
			
				
					|  |  |  |     env = gym.make(name, randomize_start=False) | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=False) | 
			
		
	
		
			
				
					|  |  |  |     # env = minigrid.wrappers.ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     # env = ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     env = OneHotShieldingWrapper(env, | 
			
		
	
	
		
			
				
					|  |  | @ -41,6 +42,8 @@ def shielding_env_creater(config): | 
			
		
	
		
			
				
					|  |  |  |                         framestack=framestack | 
			
		
	
		
			
				
					|  |  |  |                         ) | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     env.randomize_start = False | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     return env | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -62,32 +65,53 @@ register_minigrid_shielding_env(args) | 
			
		
	
		
			
				
					|  |  |  | # Use the Algorithm's `from_checkpoint` utility to get a new algo instance | 
			
		
	
		
			
				
					|  |  |  | # that has the exact same state as the old one, from which the checkpoint was | 
			
		
	
		
			
				
					|  |  |  | # created in the first place: | 
			
		
	
		
			
				
					|  |  |  | path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005' | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | algo = Algorithm.from_checkpoint(path_to_checkpoint) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | # Continue training. | 
			
		
	
		
			
				
					|  |  |  | name = "MiniGrid-LavaSlipperyS12-v2" | 
			
		
	
		
			
				
					|  |  |  | shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | env = gym.make(name) | 
			
		
	
		
			
				
					|  |  |  | env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) | 
			
		
	
		
			
				
					|  |  |  | # env = minigrid.wrappers.ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  | # env = ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  | env = OneHotShieldingWrapper(env, | 
			
		
	
		
			
				
					|  |  |  | # checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:none-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030', 'No_shield'), | 
			
		
	
		
			
				
					|  |  |  | #                         ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000030", "Rel_06_high"), | 
			
		
	
		
			
				
					|  |  |  | #                         ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_06_med"), | 
			
		
	
		
			
				
					|  |  |  | #                         ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_06_low"), | 
			
		
	
		
			
				
					|  |  |  | #                         ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000016", "Rel_1_high"), | 
			
		
	
		
			
				
					|  |  |  | #                         ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_1_med"), | 
			
		
	
		
			
				
					|  |  |  | #                         ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_1_low")] | 
			
		
	
		
			
				
					|  |  |  | checkpoints = [ | 
			
		
	
		
			
				
					|  |  |  |     # ('/home/knolli/Documents/University/Thesis/log_results/sh:none-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "no_shielding"), | 
			
		
	
		
			
				
					|  |  |  |                 # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_09"), | 
			
		
	
		
			
				
					|  |  |  |                 # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:1.0-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_1")] | 
			
		
	
		
			
				
					|  |  |  | ('/home/knolli/Documents/University/Thesis/logresults/exp/trial_0_2024-01-09_22-39-43/checkpoint_000002', 'v3')] | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | # checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_prob.yaml/checkpoint_000060', "Shielded_Gif")] | 
			
		
	
		
			
				
					|  |  |  | for path_to_checkpoint, gif_name in checkpoints: | 
			
		
	
		
			
				
					|  |  |  |     algo = Algorithm.from_checkpoint(path_to_checkpoint) | 
			
		
	
		
			
				
					|  |  |  |     policy = algo.get_policy() | 
			
		
	
		
			
				
					|  |  |  |     # Continue training. | 
			
		
	
		
			
				
					|  |  |  |     name = "MiniGrid-LavaSlipperyS12-v0" | 
			
		
	
		
			
				
					|  |  |  |     shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     env = gym.make(name, randomize_start=False, probability_forward=3/9, probability_direct_neighbour=5/9, probability_next_neighbour=7/9,) | 
			
		
	
		
			
				
					|  |  |  |     env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True) | 
			
		
	
		
			
				
					|  |  |  |     # env = minigrid.wrappers.ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     # env = ImgObsWrapper(env) | 
			
		
	
		
			
				
					|  |  |  |     env = OneHotShieldingWrapper(env, | 
			
		
	
		
			
				
					|  |  |  |                         0, | 
			
		
	
		
			
				
					|  |  |  |                         framestack=4 | 
			
		
	
		
			
				
					|  |  |  |                         ) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | episode_reward = 0 | 
			
		
	
		
			
				
					|  |  |  | terminated = truncated = False | 
			
		
	
		
			
				
					|  |  |  |     episode_reward = 0 | 
			
		
	
		
			
				
					|  |  |  |     terminated = truncated = False | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  | obs, info = env.reset() | 
			
		
	
		
			
				
					|  |  |  | i = 0 | 
			
		
	
		
			
				
					|  |  |  | filenames = [] | 
			
		
	
		
			
				
					|  |  |  | while not terminated and not truncated: | 
			
		
	
		
			
				
					|  |  |  |     obs, info = env.reset() | 
			
		
	
		
			
				
					|  |  |  |     i = 0 | 
			
		
	
		
			
				
					|  |  |  |     filenames = [] | 
			
		
	
		
			
				
					|  |  |  |     while not terminated and not truncated: | 
			
		
	
		
			
				
					|  |  |  |         action = algo.compute_single_action(obs) | 
			
		
	
		
			
				
					|  |  |  |         policy_actions = policy.compute_single_action(obs) | 
			
		
	
		
			
				
					|  |  |  |         # print(f'Policy actions {policy_actions}') | 
			
		
	
		
			
				
					|  |  |  |         # print(f'Policy actions {policy_actions.logits}') | 
			
		
	
		
			
				
					|  |  |  |         policy_action = policy_actions[2]['action_dist_inputs'].argmax() | 
			
		
	
		
			
				
					|  |  |  |         # print(f'The action is: {action} vs policy action {policy_action}') | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         if policy_action != action: | 
			
		
	
		
			
				
					|  |  |  |             print('policy action deviated') | 
			
		
	
		
			
				
					|  |  |  |             action = policy_action | 
			
		
	
		
			
				
					|  |  |  |         obs, reward, terminated, truncated, info = env.step(action) | 
			
		
	
		
			
				
					|  |  |  |         episode_reward += reward | 
			
		
	
		
			
				
					|  |  |  |         filename = F"./frames/{i}.jpg" | 
			
		
	
	
		
			
				
					|  |  | @ -96,9 +120,12 @@ while not terminated and not truncated: | 
			
		
	
		
			
				
					|  |  |  |         filenames.append(filename) | 
			
		
	
		
			
				
					|  |  |  |         i  = i + 1 | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  | import imageio | 
			
		
	
		
			
				
					|  |  |  | images = [] | 
			
		
	
		
			
				
					|  |  |  | for filename in filenames: | 
			
		
	
		
			
				
					|  |  |  |     import imageio | 
			
		
	
		
			
				
					|  |  |  |     images = [] | 
			
		
	
		
			
				
					|  |  |  |     for filename in filenames: | 
			
		
	
		
			
				
					|  |  |  |         images.append(imageio.imread(filename)) | 
			
		
	
		
			
				
					|  |  |  | imageio.mimsave('./movie.gif', images) | 
			
		
	
		
			
				
					|  |  |  |     imageio.mimsave(F'./{gif_name}.gif', images) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     for filename in filenames: | 
			
		
	
		
			
				
					|  |  |  |         os.remove(filename) | 
			
		
	
		
			
				
					|  |  |  |      |