Browse Source

Remove intermediate JPG frames after creating the GIF

refactoring
Thomas Knoll 1 year ago
parent
commit
175171c035
  1. 8
      examples/shields/rl/callbacks.py
  2. 115
      examples/shields/rl/checkpoint.py
  3. 5
      examples/shields/rl/utils.py

8
examples/shields/rl/callbacks.py

@ -12,9 +12,10 @@ from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2 from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
from ray.tune import Callback
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class CustomCallback(DefaultCallbacks): class CustomCallback(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None: def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
@ -27,11 +28,10 @@ class CustomCallback(DefaultCallbacks):
episode.hist_data["goals_reached"] = [] episode.hist_data["goals_reached"] = []
episode.hist_data["ran_into_adversary"] = [] episode.hist_data["ran_into_adversary"] = []
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1 episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0] env = base_env.get_sub_environments()[0]
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None: def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
@ -40,7 +40,7 @@ class CustomCallback(DefaultCallbacks):
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1]) agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
ran_into_adversary = False ran_into_adversary = False
if hasattr(env, "adversaries"): if hasattr(env, "adversaries") and env.adversaries:
adversaries = env.adversaries.values() adversaries = env.adversaries.values()
for adversary in adversaries: for adversary in adversaries:
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]: if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:

115
examples/shields/rl/checkpoint.py

@ -12,13 +12,14 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm import Algorithm
from torch_action_mask_model import TorchActionMaskModel from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig from utils import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query from utils import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks from callbacks import CustomCallback
from ray.tune.logger import TBXLogger from ray.tune.logger import TBXLogger
import imageio import imageio
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -32,8 +33,8 @@ def shielding_env_creater(config):
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name) env = gym.make(name, randomize_start=False)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=False)
# env = minigrid.wrappers.ImgObsWrapper(env) # env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env) # env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env, env = OneHotShieldingWrapper(env,
@ -41,6 +42,8 @@ def shielding_env_creater(config):
framestack=framestack framestack=framestack
) )
env.randomize_start = False
return env return env
@ -62,43 +65,67 @@ register_minigrid_shielding_env(args)
# Use the Algorithm's `from_checkpoint` utility to get a new algo instance # Use the Algorithm's `from_checkpoint` utility to get a new algo instance
# that has the exact same state as the old one, from which the checkpoint was # that has the exact same state as the old one, from which the checkpoint was
# created in the first place: # created in the first place:
path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005' # checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:none-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030', 'No_shield'),
# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000030", "Rel_06_high"),
# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_06_med"),
algo = Algorithm.from_checkpoint(path_to_checkpoint) # ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_06_low"),
# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000016", "Rel_1_high"),
# Continue training. # ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_1_med"),
name = "MiniGrid-LavaSlipperyS12-v2" # ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_1_low")]
shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) checkpoints = [
# ('/home/knolli/Documents/University/Thesis/log_results/sh:none-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "no_shielding"),
env = gym.make(name) # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_09"),
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:1.0-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_1")]
# env = minigrid.wrappers.ImgObsWrapper(env) ('/home/knolli/Documents/University/Thesis/logresults/exp/trial_0_2024-01-09_22-39-43/checkpoint_000002', 'v3')]
# env = ImgObsWrapper(env) # checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_prob.yaml/checkpoint_000060', "Shielded_Gif")]
env = OneHotShieldingWrapper(env, for path_to_checkpoint, gif_name in checkpoints:
0, algo = Algorithm.from_checkpoint(path_to_checkpoint)
framestack=4 policy = algo.get_policy()
) # Continue training.
name = "MiniGrid-LavaSlipperyS12-v0"
episode_reward = 0 shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula)
terminated = truncated = False env = gym.make(name, randomize_start=False, probability_forward=3/9, probability_direct_neighbour=5/9, probability_next_neighbour=7/9,)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)
obs, info = env.reset() # env = minigrid.wrappers.ImgObsWrapper(env)
i = 0 # env = ImgObsWrapper(env)
filenames = [] env = OneHotShieldingWrapper(env,
while not terminated and not truncated: 0,
action = algo.compute_single_action(obs) framestack=4
obs, reward, terminated, truncated, info = env.step(action) )
episode_reward += reward episode_reward = 0
filename = F"./frames/{i}.jpg" terminated = truncated = False
img = env.get_frame()
plt.imsave(filename, img)
filenames.append(filename)
i = i + 1
import imageio obs, info = env.reset()
images = [] i = 0
for filename in filenames: filenames = []
images.append(imageio.imread(filename)) while not terminated and not truncated:
imageio.mimsave('./movie.gif', images) action = algo.compute_single_action(obs)
policy_actions = policy.compute_single_action(obs)
# print(f'Policy actions {policy_actions}')
# print(f'Policy actions {policy_actions.logits}')
policy_action = policy_actions[2]['action_dist_inputs'].argmax()
# print(f'The action is: {action} vs policy action {policy_action}')
if policy_action != action:
print('policy action deviated')
action = policy_action
obs, reward, terminated, truncated, info = env.step(action)
episode_reward += reward
filename = F"./frames/{i}.jpg"
img = env.get_frame()
plt.imsave(filename, img)
filenames.append(filename)
i = i + 1
import imageio
images = []
for filename in filenames:
images.append(imageio.imread(filename))
imageio.mimsave(F'./{gif_name}.gif', images)
for filename in filenames:
os.remove(filename)

5
examples/shields/rl/utils.py

@ -235,7 +235,7 @@ def extract_doors(env):
def extract_adversaries(env): def extract_adversaries(env):
adv = [] adv = []
if not hasattr(env, "adversaries"): if not hasattr(env, "adversaries") or not env.adversaries:
return [] return []
for color, adversary in env.adversaries.items(): for color, adversary in env.adversaries.items():
@ -286,7 +286,8 @@ def parse_arguments(argparse):
"MiniGrid-AdvSimple-8x8-v0", "MiniGrid-AdvSimple-8x8-v0",
"MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0", "MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyCliffS12-v0" "MiniGrid-LavaSlipperyCliffS12-v0",
"MiniGrid-LavaFaultyS12-30-v0",
]) ])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None) # parser.add_argument("--seed", type=int, help="seed for environment", default=None)

|||||||
100:0
Loading…
Cancel
Save