From 175171c035fac78e1d8535d2f6cf15ed8c0e2ad0 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Tue, 9 Jan 2024 22:53:18 +0100 Subject: [PATCH] added jpg remove on gif create --- examples/shields/rl/callbacks.py | 8 +-- examples/shields/rl/checkpoint.py | 115 ++++++++++++++++++------------ examples/shields/rl/utils.py | 5 +- 3 files changed, 78 insertions(+), 50 deletions(-) diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py index 4cdbac5..1c856e4 100644 --- a/examples/shields/rl/callbacks.py +++ b/examples/shields/rl/callbacks.py @@ -12,9 +12,10 @@ from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.episode_v2 import EpisodeV2 from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks +from ray.tune import Callback import matplotlib.pyplot as plt - + class CustomCallback(DefaultCallbacks): def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: @@ -27,11 +28,10 @@ class CustomCallback(DefaultCallbacks): episode.hist_data["goals_reached"] = [] episode.hist_data["ran_into_adversary"] = [] - def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: episode.user_data["count"] = episode.user_data["count"] + 1 env = base_env.get_sub_environments()[0] - + # print(env.printGrid()) def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: @@ -40,7 +40,7 @@ class CustomCallback(DefaultCallbacks): agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1]) ran_into_adversary = False - if hasattr(env, "adversaries"): + if hasattr(env, "adversaries") and env.adversaries: adversaries = env.adversaries.values() for adversary in adversaries: if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]: diff --git a/examples/shields/rl/checkpoint.py b/examples/shields/rl/checkpoint.py index 65ca139..05cfff8 100644 --- a/examples/shields/rl/checkpoint.py +++ b/examples/shields/rl/checkpoint.py @@ -12,13 +12,14 @@ from ray.rllib.models import ModelCatalog from ray.rllib.algorithms.algorithm import Algorithm from torch_action_mask_model import TorchActionMaskModel -from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper -from helpers import parse_arguments, create_log_dir, ShieldingConfig -from shieldhandlers import MiniGridShieldHandler, create_shield_query -from callbacks import MyCallbacks +from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper +from utils import parse_arguments, create_log_dir, ShieldingConfig +from utils import MiniGridShieldHandler, create_shield_query +from callbacks import CustomCallback from ray.tune.logger import TBXLogger import imageio +import os import matplotlib.pyplot as plt @@ -32,8 +33,8 @@ def shielding_env_creater(config): shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) - env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) + env = gym.make(name, randomize_start=False) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=False) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) env = OneHotShieldingWrapper(env, @@ -41,6 +42,8 @@ def shielding_env_creater(config): framestack=framestack ) + env.randomize_start = False + return env @@ -62,43 +65,67 @@ register_minigrid_shielding_env(args) # Use the Algorithm's `from_checkpoint` utility to get a new algo instance # that has the exact same state as the old one, from which the checkpoint was # created in the first place: -path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005' - - -algo = Algorithm.from_checkpoint(path_to_checkpoint) - -# Continue training. -name = "MiniGrid-LavaSlipperyS12-v2" -shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) - -env = gym.make(name) -env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) -# env = minigrid.wrappers.ImgObsWrapper(env) -# env = ImgObsWrapper(env) -env = OneHotShieldingWrapper(env, - 0, - framestack=4 - ) - -episode_reward = 0 -terminated = truncated = False - -obs, info = env.reset() -i = 0 -filenames = [] -while not terminated and not truncated: - action = algo.compute_single_action(obs) - obs, reward, terminated, truncated, info = env.step(action) - episode_reward += reward - filename = F"./frames/{i}.jpg" - img = env.get_frame() - plt.imsave(filename, img) - filenames.append(filename) - i = i + 1 +# checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:none-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030', 'No_shield'), +# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000030", "Rel_06_high"), +# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_06_med"), +# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_06_low"), +# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000016", "Rel_1_high"), +# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_1_med"), +# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_1_low")] +checkpoints = [ + # ('/home/knolli/Documents/University/Thesis/log_results/sh:none-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "no_shielding"), + # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_09"), + # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:1.0-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_1")] +('/home/knolli/Documents/University/Thesis/logresults/exp/trial_0_2024-01-09_22-39-43/checkpoint_000002', 'v3')] + +# checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_prob.yaml/checkpoint_000060', "Shielded_Gif")] +for path_to_checkpoint, gif_name in checkpoints: + algo = Algorithm.from_checkpoint(path_to_checkpoint) + policy = algo.get_policy() + # Continue training. + name = "MiniGrid-LavaSlipperyS12-v0" + shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) + + env = gym.make(name, randomize_start=False, probability_forward=3/9, probability_direct_neighbour=5/9, probability_next_neighbour=7/9,) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True) + # env = minigrid.wrappers.ImgObsWrapper(env) + # env = ImgObsWrapper(env) + env = OneHotShieldingWrapper(env, + 0, + framestack=4 + ) + + episode_reward = 0 + terminated = truncated = False -import imageio -images = [] -for filename in filenames: - images.append(imageio.imread(filename)) -imageio.mimsave('./movie.gif', images) + obs, info = env.reset() + i = 0 + filenames = [] + while not terminated and not truncated: + action = algo.compute_single_action(obs) + policy_actions = policy.compute_single_action(obs) + # print(f'Policy actions {policy_actions}') + # print(f'Policy actions {policy_actions.logits}') + policy_action = policy_actions[2]['action_dist_inputs'].argmax() + # print(f'The action is: {action} vs policy action {policy_action}') + + if policy_action != action: + print('policy action deviated') + action = policy_action + obs, reward, terminated, truncated, info = env.step(action) + episode_reward += reward + filename = F"./frames/{i}.jpg" + img = env.get_frame() + plt.imsave(filename, img) + filenames.append(filename) + i = i + 1 + + import imageio + images = [] + for filename in filenames: + images.append(imageio.imread(filename)) + imageio.mimsave(F'./{gif_name}.gif', images) + + for filename in filenames: + os.remove(filename) \ No newline at end of file diff --git a/examples/shields/rl/utils.py b/examples/shields/rl/utils.py index b526e18..283eee2 100644 --- a/examples/shields/rl/utils.py +++ b/examples/shields/rl/utils.py @@ -235,7 +235,7 @@ def extract_doors(env): def extract_adversaries(env): adv = [] - if not hasattr(env, "adversaries"): + if not hasattr(env, "adversaries") or not env.adversaries: return [] for color, adversary in env.adversaries.items(): @@ -286,7 +286,8 @@ def parse_arguments(argparse): "MiniGrid-AdvSimple-8x8-v0", "MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N3-v0", - "MiniGrid-LavaSlipperyCliffS12-v0" + "MiniGrid-LavaSlipperyCliffS12-v0", + "MiniGrid-LavaFaultyS12-30-v0", ]) # parser.add_argument("--seed", type=int, help="seed for environment", default=None)