You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
6.3 KiB

2 months ago
  1. import gymnasium as gym
  2. import minigrid
  3. from ray.tune import register_env
  4. from ray.rllib.algorithms.ppo import PPOConfig
  5. from ray.rllib.algorithms.dqn.dqn import DQNConfig
  6. from ray.tune.logger import pretty_print
  7. from ray.rllib.models import ModelCatalog
  8. from ray.rllib.algorithms.algorithm import Algorithm
  9. from torch_action_mask_model import TorchActionMaskModel
  10. from rllibutils import OneHotShieldingWrapper, MiniGridShieldingWrapper
  11. from utils import parse_arguments, create_log_dir, ShieldingConfig
  12. from utils import MiniGridShieldHandler, create_shield_query
  13. from callbacks import CustomCallback
  14. from ray.tune.logger import TBXLogger
  15. import imageio
  16. import os
  17. import matplotlib.pyplot as plt
  18. def shielding_env_creater(config):
  19. name = config.get("name", "MiniGrid-LavaSlipperyS12-v2")
  20. framestack = config.get("framestack", 4)
  21. args = config.get("args", None)
  22. args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
  23. args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
  24. shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
  25. env = gym.make(name, randomize_start=False)
  26. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=False)
  27. # env = minigrid.wrappers.ImgObsWrapper(env)
  28. # env = ImgObsWrapper(env)
  29. env = OneHotShieldingWrapper(env,
  30. config.vector_index if hasattr(config, "vector_index") else 0,
  31. framestack=framestack
  32. )
  33. env.randomize_start = False
  34. return env
  35. def register_minigrid_shielding_env(args):
  36. env_name = "mini-grid-shielding"
  37. register_env(env_name, shielding_env_creater)
  38. ModelCatalog.register_custom_model(
  39. "shielding_model",
  40. TorchActionMaskModel
  41. )
  42. import argparse
  43. args = parse_arguments(argparse)
  44. register_minigrid_shielding_env(args)
  45. # Use the Algorithm's `from_checkpoint` utility to get a new algo instance
  46. # that has the exact same state as the old one, from which the checkpoint was
  47. # created in the first place:
  48. # checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:none-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030', 'No_shield'),
  49. # ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000030", "Rel_06_high"),
  50. # ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_06_med"),
  51. # ("/home/knolli/Documents/University/Thesis/log_results/Relative_06/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_06_low"),
  52. # ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_high.yaml/checkpoint_000016", "Rel_1_high"),
  53. # ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_medium.yaml/checkpoint_000030", "Rel_1_med"),
  54. # ("/home/knolli/Documents/University/Thesis/log_results/RELATIVE_1/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:adv_config_slippery_low.yaml/checkpoint_000030", "Rel_1_low")]
  55. checkpoints = [
  56. # ('/home/knolli/Documents/University/Thesis/log_results/sh:none-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "no_shielding"),
  57. # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:0.9-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_09"),
  58. # ('/home/knolli/Documents/University/Thesis/log_results/sh:full-value:1.0-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_pro.yaml/checkpoint_000070', "shielding_1")]
  59. ('/home/knolli/Documents/University/Thesis/logresults/exp/trial_0_2024-01-09_22-39-43/checkpoint_000002', 'v3')]
  60. # checkpoints = [('/home/knolli/Documents/University/Thesis/log_results/sh:full-env:MiniGrid-LavaSlipperyS12-v2-conf:slippery_high_prob.yaml/checkpoint_000060', "Shielded_Gif")]
  61. for path_to_checkpoint, gif_name in checkpoints:
  62. algo = Algorithm.from_checkpoint(path_to_checkpoint)
  63. policy = algo.get_policy()
  64. # Continue training.
  65. name = "MiniGrid-LavaSlipperyS12-v0"
  66. shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula)
  67. env = gym.make(name, randomize_start=False, probability_forward=3/9, probability_direct_neighbour=5/9, probability_next_neighbour=7/9,)
  68. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)
  69. # env = minigrid.wrappers.ImgObsWrapper(env)
  70. # env = ImgObsWrapper(env)
  71. env = OneHotShieldingWrapper(env,
  72. 0,
  73. framestack=4
  74. )
  75. episode_reward = 0
  76. terminated = truncated = False
  77. obs, info = env.reset()
  78. i = 0
  79. filenames = []
  80. while not terminated and not truncated:
  81. action = algo.compute_single_action(obs)
  82. policy_actions = policy.compute_single_action(obs)
  83. # print(f'Policy actions {policy_actions}')
  84. # print(f'Policy actions {policy_actions.logits}')
  85. policy_action = policy_actions[2]['action_dist_inputs'].argmax()
  86. # print(f'The action is: {action} vs policy action {policy_action}')
  87. if policy_action != action:
  88. print('policy action deviated')
  89. action = policy_action
  90. obs, reward, terminated, truncated, info = env.step(action)
  91. episode_reward += reward
  92. filename = F"./frames/{i}.jpg"
  93. img = env.get_frame()
  94. plt.imsave(filename, img)
  95. filenames.append(filename)
  96. i = i + 1
  97. import imageio
  98. images = []
  99. for filename in filenames:
  100. images.append(imageio.imread(filename))
  101. imageio.mimsave(F'./{gif_name}.gif', images)
  102. for filename in filenames:
  103. os.remove(filename)