Browse Source

added jpg remove on gif create

refactoring
Thomas Knoll 1 year ago
parent
commit
175171c035
  1. 8
      examples/shields/rl/callbacks.py
  2. 115
      examples/shields/rl/checkpoint.py
  3. 5
      examples/shields/rl/utils.py

8
examples/shields/rl/callbacks.py

@ -12,9 +12,10 @@ from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
from ray.tune import Callback
import matplotlib.pyplot as plt
class CustomCallback(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode, env_index, **kwargs) -> None:
@ -27,11 +28,10 @@ class CustomCallback(DefaultCallbacks):
episode.hist_data["goals_reached"] = []
episode.hist_data["ran_into_adversary"] = []
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies, episode, env_index, **kwargs) -> None:
@ -40,7 +40,7 @@ class CustomCallback(DefaultCallbacks):
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
ran_into_adversary = False
if hasattr(env, "adversaries"):
if hasattr(env, "adversaries") and env.adversaries:
adversaries = env.adversaries.values()
for adversary in adversaries:
if adversary.cur_pos[0] == env.agent_pos[0] and adversary.cur_pos[1] == env.agent_pos[1]:

115
examples/shields/rl/checkpoint.py

@ -12,13 +12,14 @@ from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.algorithm import Algorithm
from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
from utils import parse_arguments, create_log_dir, ShieldingConfig
from utils import MiniGridShieldHandler, create_shield_query
from callbacks import CustomCallback
from ray.tune.logger import TBXLogger
import imageio
import os
import matplotlib.pyplot as plt
@ -32,8 +33,8 @@ def shielding_env_creater(config):
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
env = gym.make(name, randomize_start=False)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=False)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env,
@ -41,6 +42,8 @@ def shielding_env_creater(config):
framestack=framestack
)
env.randomize_start = False
return env
@ -62,43 +65,67 @@ register_minigrid_shielding_env(args)
# Use the Algorithm's `from_checkpoint` utility to get a new algo instance
# that has the exact same state as the old one, from which the checkpoint was
# created in the first place:
path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005'
algo = Algorithm.from_checkpoint(path_to_checkpoint)
# Continue training.
name = "MiniGrid-LavaSlipperyS12-v2"
shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env,
0,
framestack=4
)
episode_reward = 0
terminated = truncated = False
obs, info = env.reset()
i = 0
filenames = []
while not terminated and not truncated:
action = algo.compute_single_action(obs)
obs, reward, terminated, truncated, info = env.step(action)
episode_reward += reward
filename = F"./frames/{i}.jpg"
img = env.get_frame()
plt.imsave(filename, img)
filenames.append(filename)
i = i + 1
# checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:none-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030', 'No_shield'),
# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000030", "Rel_06_high"),
# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_06_med"),
# ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_06_low"),
# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000016", "Rel_1_high"),
# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_1_med"),
# ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_1_low")]
checkpoints = [
# ('/home/knolli/Documents/University/Thesis/log_results/sh:none-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "no_shielding"),
# ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_09"),
# ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:1.0-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_1")]
('/home/knolli/Documents/University/Thesis/logresults/exp/trial_0_2024-01-09_22-39-43/checkpoint_000002', 'v3')]
# checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_prob.yaml/checkpoint_000060', "Shielded_Gif")]
for path_to_checkpoint, gif_name in checkpoints:
algo = Algorithm.from_checkpoint(path_to_checkpoint)
policy = algo.get_policy()
# Continue training.
name = "MiniGrid-LavaSlipperyS12-v0"
shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula)
env = gym.make(name, randomize_start=False, probability_forward=3/9, probability_direct_neighbour=5/9, probability_next_neighbour=7/9,)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env,
0,
framestack=4
)
episode_reward = 0
terminated = truncated = False
import imageio
images = []
for filename in filenames:
images.append(imageio.imread(filename))
imageio.mimsave('./movie.gif', images)
obs, info = env.reset()
i = 0
filenames = []
while not terminated and not truncated:
action = algo.compute_single_action(obs)
policy_actions = policy.compute_single_action(obs)
# print(f'Policy actions {policy_actions}')
# print(f'Policy actions {policy_actions.logits}')
policy_action = policy_actions[2]['action_dist_inputs'].argmax()
# print(f'The action is: {action} vs policy action {policy_action}')
if policy_action != action:
print('policy action deviated')
action = policy_action
obs, reward, terminated, truncated, info = env.step(action)
episode_reward += reward
filename = F"./frames/{i}.jpg"
img = env.get_frame()
plt.imsave(filename, img)
filenames.append(filename)
i = i + 1
import imageio
images = []
for filename in filenames:
images.append(imageio.imread(filename))
imageio.mimsave(F'./{gif_name}.gif', images)
for filename in filenames:
os.remove(filename)

5
examples/shields/rl/utils.py

@ -235,7 +235,7 @@ def extract_doors(env):
def extract_adversaries(env):
adv = []
if not hasattr(env, "adversaries"):
if not hasattr(env, "adversaries") or not env.adversaries:
return []
for color, adversary in env.adversaries.items():
@ -286,7 +286,8 @@ def parse_arguments(argparse):
"MiniGrid-AdvSimple-8x8-v0",
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-LavaSlipperyCliffS12-v0"
"MiniGrid-LavaSlipperyCliffS12-v0",
"MiniGrid-LavaFaultyS12-30-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)

Loading…
Cancel
Save