From 71acf4e2cc97396bed29148c9c0eba0840dac654 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Mon, 27 Nov 2023 12:27:23 +0100 Subject: [PATCH] added checkpoint & sandbox --- examples/shields/rl/checkpoint.py | 104 ++++++++++++++++++++++++++++++ examples/shields/rl/sandbox.py | 66 +++++++++++++++++++ 2 files changed, 170 insertions(+) create mode 100644 examples/shields/rl/checkpoint.py create mode 100644 examples/shields/rl/sandbox.py diff --git a/examples/shields/rl/checkpoint.py b/examples/shields/rl/checkpoint.py new file mode 100644 index 0000000..65ca139 --- /dev/null +++ b/examples/shields/rl/checkpoint.py @@ -0,0 +1,104 @@ + + +import gymnasium as gym +import minigrid + +from ray.tune import register_env +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.dqn.dqn import DQNConfig +from ray.tune.logger import pretty_print +from ray.rllib.models import ModelCatalog + +from ray.rllib.algorithms.algorithm import Algorithm + +from torch_action_mask_model import TorchActionMaskModel +from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper +from helpers import parse_arguments, create_log_dir, ShieldingConfig +from shieldhandlers import MiniGridShieldHandler, create_shield_query +from callbacks import MyCallbacks + +from ray.tune.logger import TBXLogger +import imageio + +import matplotlib.pyplot as plt + + +def shielding_env_creater(config): + name = config.get("name", "MiniGrid-LavaSlipperyS12-v2") + framestack = config.get("framestack", 4) + args = config.get("args", None) + args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" + args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" + + shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) + + env = gym.make(name) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) + # env = minigrid.wrappers.ImgObsWrapper(env) + # env = ImgObsWrapper(env) + env = OneHotShieldingWrapper(env, + config.vector_index if hasattr(config, "vector_index") else 0, + framestack=framestack + ) + + + return env + + +def register_minigrid_shielding_env(args): + env_name = "mini-grid-shielding" + register_env(env_name, shielding_env_creater) + + ModelCatalog.register_custom_model( + "shielding_model", + TorchActionMaskModel + ) + +import argparse +args = parse_arguments(argparse) + +register_minigrid_shielding_env(args) + +# Use the Algorithm's `from_checkpoint` utility to get a new algo instance +# that has the exact same state as the old one, from which the checkpoint was +# created in the first place: +path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005' + + +algo = Algorithm.from_checkpoint(path_to_checkpoint) + +# Continue training. +name = "MiniGrid-LavaSlipperyS12-v2" +shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) + +env = gym.make(name) +env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) +# env = minigrid.wrappers.ImgObsWrapper(env) +# env = ImgObsWrapper(env) +env = OneHotShieldingWrapper(env, + 0, + framestack=4 + ) + +episode_reward = 0 +terminated = truncated = False + +obs, info = env.reset() +i = 0 +filenames = [] +while not terminated and not truncated: + action = algo.compute_single_action(obs) + obs, reward, terminated, truncated, info = env.step(action) + episode_reward += reward + filename = F"./frames/{i}.jpg" + img = env.get_frame() + plt.imsave(filename, img) + filenames.append(filename) + i = i + 1 + +import imageio +images = [] +for filename in filenames: + images.append(imageio.imread(filename)) +imageio.mimsave('./movie.gif', images) + \ No newline at end of file diff --git a/examples/shields/rl/sandbox.py b/examples/shields/rl/sandbox.py new file mode 100644 index 0000000..cc77202 --- /dev/null +++ b/examples/shields/rl/sandbox.py @@ -0,0 +1,66 @@ +import minigrid +import gymnasium as gym +import random +import matplotlib.pyplot as plt + +from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper + +def main(): + #samples = random.choices([0.0, 1.0], weights=(0.25, 0.75), k=100_000_000) + # print(samples) + + # print(sum(samples)) + #print(sum(samples) / len(samples)) + + + + names = [ + "MiniGrid-Adv-8x8-v0", + "MiniGrid-AdvSimple-8x8-v0", + "MiniGrid-AdvSlippery-8x8-v0", + "MiniGrid-AdvLava-8x8-v0", + # "MiniGrid-SingleDoor-7x6-v0", + # "MiniGrid-DoubleDoor-10x8-v0", + # "MiniGrid-DoubleDoor-12x12-v0", + # "MiniGrid-DoubleDoor-16x16-v0", + # "MiniGrid-LavaSlipperyS12-v0", + # "MiniGrid-LavaSlipperyS12-v1", + # "MiniGrid-LavaSlipperyS12-v2", + # "MiniGrid-LavaSlipperyS12-v3", + # "MiniGrid-LavaCrossingS9N3-v0", + # "MiniGrid-SimpleCrossingS9N3-v0", + + # "MiniGrid-ObstructedMaze-1Dlh-v0", + # "MiniGrid-BlockedUnlockPickup-v0", + # "MiniGrid-KeyCorridorS6R3-v0", + # "MiniGrid-LockedRoom-v0", + # "MiniGrid-KeyCorridorS3R1-v0", + # "MiniGrid-LavaGapS7-v0", + # "MiniGrid-DoorKey-8x8-v0", + # "MiniGrid-Dynamic-Obstacles-8x8-v0", + # "MiniGrid-Empty-Random-6x6-v0", + # "MiniGrid-Fetch-6x6-N2-v0", + # "MiniGrid-FourRooms-v0", + # "MiniGrid-LavaGapS7-v0", + # "MiniGrid-RedBlueDoors-6x6-v0", + ] + + for name in names: + env = gym.make(name) + env = RGBImgPartialObsWrapper(env) + env = ImgObsWrapper(env) + + env.reset() + + img = env.get_frame(highlight=False) + plt.title(name) + plt.imshow(img) + f = open(F"{name}.txt", "w") + f.write(env.printGrid(init=True)) + f.close() + + + plt.show() + +if __name__ == '__main__': + main() \ No newline at end of file