Thomas Knoll
1 year ago
2 changed files with 170 additions and 0 deletions
@ -0,0 +1,104 @@ |
|||||
|
|
||||
|
|
||||
|
import gymnasium as gym |
||||
|
import minigrid |
||||
|
|
||||
|
from ray.tune import register_env |
||||
|
from ray.rllib.algorithms.ppo import PPOConfig |
||||
|
from ray.rllib.algorithms.dqn.dqn import DQNConfig |
||||
|
from ray.tune.logger import pretty_print |
||||
|
from ray.rllib.models import ModelCatalog |
||||
|
|
||||
|
from ray.rllib.algorithms.algorithm import Algorithm |
||||
|
|
||||
|
from torch_action_mask_model import TorchActionMaskModel |
||||
|
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper |
||||
|
from helpers import parse_arguments, create_log_dir, ShieldingConfig |
||||
|
from shieldhandlers import MiniGridShieldHandler, create_shield_query |
||||
|
from callbacks import MyCallbacks |
||||
|
|
||||
|
from ray.tune.logger import TBXLogger |
||||
|
import imageio |
||||
|
|
||||
|
import matplotlib.pyplot as plt |
||||
|
|
||||
|
|
||||
|
def shielding_env_creater(config): |
||||
|
name = config.get("name", "MiniGrid-LavaSlipperyS12-v2") |
||||
|
framestack = config.get("framestack", 4) |
||||
|
args = config.get("args", None) |
||||
|
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" |
||||
|
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" |
||||
|
|
||||
|
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) |
||||
|
|
||||
|
env = gym.make(name) |
||||
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) |
||||
|
# env = minigrid.wrappers.ImgObsWrapper(env) |
||||
|
# env = ImgObsWrapper(env) |
||||
|
env = OneHotShieldingWrapper(env, |
||||
|
config.vector_index if hasattr(config, "vector_index") else 0, |
||||
|
framestack=framestack |
||||
|
) |
||||
|
|
||||
|
|
||||
|
return env |
||||
|
|
||||
|
|
||||
|
def register_minigrid_shielding_env(args): |
||||
|
env_name = "mini-grid-shielding" |
||||
|
register_env(env_name, shielding_env_creater) |
||||
|
|
||||
|
ModelCatalog.register_custom_model( |
||||
|
"shielding_model", |
||||
|
TorchActionMaskModel |
||||
|
) |
||||
|
|
||||
|
import argparse |
||||
|
args = parse_arguments(argparse) |
||||
|
|
||||
|
register_minigrid_shielding_env(args) |
||||
|
|
||||
|
# Use the Algorithm's `from_checkpoint` utility to get a new algo instance |
||||
|
# that has the exact same state as the old one, from which the checkpoint was |
||||
|
# created in the first place: |
||||
|
path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005' |
||||
|
|
||||
|
|
||||
|
algo = Algorithm.from_checkpoint(path_to_checkpoint) |
||||
|
|
||||
|
# Continue training. |
||||
|
name = "MiniGrid-LavaSlipperyS12-v2" |
||||
|
shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula) |
||||
|
|
||||
|
env = gym.make(name) |
||||
|
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) |
||||
|
# env = minigrid.wrappers.ImgObsWrapper(env) |
||||
|
# env = ImgObsWrapper(env) |
||||
|
env = OneHotShieldingWrapper(env, |
||||
|
0, |
||||
|
framestack=4 |
||||
|
) |
||||
|
|
||||
|
episode_reward = 0 |
||||
|
terminated = truncated = False |
||||
|
|
||||
|
obs, info = env.reset() |
||||
|
i = 0 |
||||
|
filenames = [] |
||||
|
while not terminated and not truncated: |
||||
|
action = algo.compute_single_action(obs) |
||||
|
obs, reward, terminated, truncated, info = env.step(action) |
||||
|
episode_reward += reward |
||||
|
filename = F"./frames/{i}.jpg" |
||||
|
img = env.get_frame() |
||||
|
plt.imsave(filename, img) |
||||
|
filenames.append(filename) |
||||
|
i = i + 1 |
||||
|
|
||||
|
import imageio |
||||
|
images = [] |
||||
|
for filename in filenames: |
||||
|
images.append(imageio.imread(filename)) |
||||
|
imageio.mimsave('./movie.gif', images) |
||||
|
|
@ -0,0 +1,66 @@ |
|||||
|
import minigrid |
||||
|
import gymnasium as gym |
||||
|
import random |
||||
|
import matplotlib.pyplot as plt |
||||
|
|
||||
|
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper |
||||
|
|
||||
|
def main(): |
||||
|
#samples = random.choices([0.0, 1.0], weights=(0.25, 0.75), k=100_000_000) |
||||
|
# print(samples) |
||||
|
|
||||
|
# print(sum(samples)) |
||||
|
#print(sum(samples) / len(samples)) |
||||
|
|
||||
|
|
||||
|
|
||||
|
names = [ |
||||
|
"MiniGrid-Adv-8x8-v0", |
||||
|
"MiniGrid-AdvSimple-8x8-v0", |
||||
|
"MiniGrid-AdvSlippery-8x8-v0", |
||||
|
"MiniGrid-AdvLava-8x8-v0", |
||||
|
# "MiniGrid-SingleDoor-7x6-v0", |
||||
|
# "MiniGrid-DoubleDoor-10x8-v0", |
||||
|
# "MiniGrid-DoubleDoor-12x12-v0", |
||||
|
# "MiniGrid-DoubleDoor-16x16-v0", |
||||
|
# "MiniGrid-LavaSlipperyS12-v0", |
||||
|
# "MiniGrid-LavaSlipperyS12-v1", |
||||
|
# "MiniGrid-LavaSlipperyS12-v2", |
||||
|
# "MiniGrid-LavaSlipperyS12-v3", |
||||
|
# "MiniGrid-LavaCrossingS9N3-v0", |
||||
|
# "MiniGrid-SimpleCrossingS9N3-v0", |
||||
|
|
||||
|
# "MiniGrid-ObstructedMaze-1Dlh-v0", |
||||
|
# "MiniGrid-BlockedUnlockPickup-v0", |
||||
|
# "MiniGrid-KeyCorridorS6R3-v0", |
||||
|
# "MiniGrid-LockedRoom-v0", |
||||
|
# "MiniGrid-KeyCorridorS3R1-v0", |
||||
|
# "MiniGrid-LavaGapS7-v0", |
||||
|
# "MiniGrid-DoorKey-8x8-v0", |
||||
|
# "MiniGrid-Dynamic-Obstacles-8x8-v0", |
||||
|
# "MiniGrid-Empty-Random-6x6-v0", |
||||
|
# "MiniGrid-Fetch-6x6-N2-v0", |
||||
|
# "MiniGrid-FourRooms-v0", |
||||
|
# "MiniGrid-LavaGapS7-v0", |
||||
|
# "MiniGrid-RedBlueDoors-6x6-v0", |
||||
|
] |
||||
|
|
||||
|
for name in names: |
||||
|
env = gym.make(name) |
||||
|
env = RGBImgPartialObsWrapper(env) |
||||
|
env = ImgObsWrapper(env) |
||||
|
|
||||
|
env.reset() |
||||
|
|
||||
|
img = env.get_frame(highlight=False) |
||||
|
plt.title(name) |
||||
|
plt.imshow(img) |
||||
|
f = open(F"{name}.txt", "w") |
||||
|
f.write(env.printGrid(init=True)) |
||||
|
f.close() |
||||
|
|
||||
|
|
||||
|
plt.show() |
||||
|
|
||||
|
if __name__ == '__main__': |
||||
|
main() |
Write
Preview
Loading…
Cancel
Save
Reference in new issue