You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

103 lines
3.4 KiB

  1. import gymnasium as gym
  2. import minigrid
  3. from ray.tune import register_env
  4. from ray.rllib.algorithms.ppo import PPOConfig
  5. from ray.rllib.algorithms.dqn.dqn import DQNConfig
  6. from ray.tune.logger import pretty_print
  7. from ray.rllib.models import ModelCatalog
  8. from ray.rllib.algorithms.algorithm import Algorithm
  9. from torch_action_mask_model import TorchActionMaskModel
  10. from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
  11. from helpers import parse_arguments, create_log_dir, ShieldingConfig
  12. from shieldhandlers import MiniGridShieldHandler, create_shield_query
  13. from callbacks import MyCallbacks
  14. from ray.tune.logger import TBXLogger
  15. import imageio
  16. import matplotlib.pyplot as plt
  17. def shielding_env_creater(config):
  18. name = config.get("name", "MiniGrid-LavaSlipperyS12-v2")
  19. framestack = config.get("framestack", 4)
  20. args = config.get("args", None)
  21. args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
  22. args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
  23. shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
  24. env = gym.make(name)
  25. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
  26. # env = minigrid.wrappers.ImgObsWrapper(env)
  27. # env = ImgObsWrapper(env)
  28. env = OneHotShieldingWrapper(env,
  29. config.vector_index if hasattr(config, "vector_index") else 0,
  30. framestack=framestack
  31. )
  32. return env
  33. def register_minigrid_shielding_env(args):
  34. env_name = "mini-grid-shielding"
  35. register_env(env_name, shielding_env_creater)
  36. ModelCatalog.register_custom_model(
  37. "shielding_model",
  38. TorchActionMaskModel
  39. )
  40. import argparse
  41. args = parse_arguments(argparse)
  42. register_minigrid_shielding_env(args)
  43. # Use the Algorithm's `from_checkpoint` utility to get a new algo instance
  44. # that has the exact same state as the old one, from which the checkpoint was
  45. # created in the first place:
  46. path_to_checkpoint = '/home/tknoll/Documents/Projects/log_results/PPO-shielding:full-evaluations:10-steps:20000-env:MiniGrid-LavaSlipperyS12-v2/PPO/PPO_mini-grid-shielding_8cd74_00000_0_2023-09-13_14-10-38/checkpoint_000005'
  47. algo = Algorithm.from_checkpoint(path_to_checkpoint)
  48. # Continue training.
  49. name = "MiniGrid-LavaSlipperyS12-v2"
  50. shield_creator = MiniGridShieldHandler(F"./{args.grid_path}_1.txt", args.grid_to_prism_binary_path, F"./{args.prism_path}_1.prism", args.formula)
  51. env = gym.make(name)
  52. env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
  53. # env = minigrid.wrappers.ImgObsWrapper(env)
  54. # env = ImgObsWrapper(env)
  55. env = OneHotShieldingWrapper(env,
  56. 0,
  57. framestack=4
  58. )
  59. episode_reward = 0
  60. terminated = truncated = False
  61. obs, info = env.reset()
  62. i = 0
  63. filenames = []
  64. while not terminated and not truncated:
  65. action = algo.compute_single_action(obs)
  66. obs, reward, terminated, truncated, info = env.step(action)
  67. episode_reward += reward
  68. filename = F"./frames/{i}.jpg"
  69. img = env.get_frame()
  70. plt.imsave(filename, img)
  71. filenames.append(filename)
  72. i = i + 1
  73. import imageio
  74. images = []
  75. for filename in filenames:
  76. images.append(imageio.imread(filename))
  77. imageio.mimsave('./movie.gif', images)