from typing import Dict, Optional, Union
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import stormpy
import stormpy.core
import stormpy.simulator
from collections import deque
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
import gymnasium as gym
import minigrid
import numpy as np
import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator
from ray import tune, air
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.utils.numpy import one_hot
from ray.rllib.algorithms import ppo
from ray.rllib.models.preprocessors import get_preprocessor
import matplotlib.pyplot as plt
import argparse
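

# RLlib callback hooks: keep a per-episode step counter in episode.user_data
# and fetch the underlying MiniGrid sub-environment for optional debug output.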
class MyCallbacks(DefaultCallbacks):
    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
        # print(F"Episode started. Environment: {base_env.get_sub_environments()}")
        env = base_env.get_sub_environments()[0]
        episode.user_data["count"] = 0
        # print(env.printGrid())
        # print(env.action_space.n)
        # print(env.actions)
        # print(env.mission)
        # print(env.observation_space)
        # img = env.get_frame()
        # plt.imshow(img)
        # plt.show()

    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
        episode.user_data["count"] = episode.user_data["count"] + 1
        env = base_env.get_sub_environments()[0]
        # print(env.env.env.printGrid())

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
        # print(F"Episode ended. Environment: {base_env.get_sub_environments()}")
        env = base_env.get_sub_environments()[0]
        # print(env.env.env.printGrid())
        # print(episode.user_data["count"])
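

# Observation wrapper: one-hot encodes the 7x7x3 MiniGrid view (object type,
# color, state) plus the agent direction, and stacks the last `framestack`
# frames into a single flat float vector.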
class OneHotWrapper(gym.core.ObservationWrapper):
    def __init__(self, env, vector_index, framestack):
        super().__init__(env)
        self.framestack = framestack
        # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types.
        # +4: Direction.
        self.single_frame_dim = 49 * (11 + 6 + 3) + 4
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        for _ in range(self.framestack):
            self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
        self.observation_space = gym.spaces.Box(
            0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32
        )

    def observation(self, obs):
        # Debug output: max-x/y positions to watch exploration progress.
        if self.step_count == 0:
            for _ in range(self.framestack):
                self.frame_buffer.append(np.zeros((self.single_frame_dim,)))
            if self.vector_index == 0:
                if self.x_positions:
                    max_diff = max(
                        np.sqrt(
                            (np.array(self.x_positions) - self.init_x) ** 2
                            + (np.array(self.y_positions) - self.init_y) ** 2
                        )
                    )
                    self.x_y_delta_buffer.append(max_diff)
                    print(
                        "100-average dist travelled={}".format(
                            np.mean(self.x_y_delta_buffer)
                        )
                    )
                    self.x_positions = []
                    self.y_positions = []
                self.init_x = self.agent_pos[0]
                self.init_y = self.agent_pos[1]
        self.x_positions.append(self.agent_pos[0])
        self.y_positions.append(self.agent_pos[1])
        # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(obs[:, :, 0], depth=11)
        colors = one_hot(obs[:, :, 1], depth=6)
        states = one_hot(obs[:, :, 2], depth=3)
        all_ = np.concatenate([objects, colors, states], -1)
        all_flat = np.reshape(all_, (-1,))
        direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)
        return np.concatenate(self.frame_buffer)
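

# Command-line options: environment name, seed, tile size for rendering, and
# the output paths for the exported grid and the generated PRISM program.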
def parse_arguments(argparse):
    parser = argparse.ArgumentParser()
    # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0")
    parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
    parser.add_argument("--seed", type=int, help="seed for environment", default=1)
    parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
    parser.add_argument("--agent_view", default=False, action="store_true", help="draw what the agent sees")
    parser.add_argument("--grid_path", default="Grid.txt")
    parser.add_argument("--prism_path", default="Grid.PRISM")
    args = parser.parse_args()
    return args
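

# Environment factory registered with RLlib: build the MiniGrid env, keep only
# the image observation (ImgObsWrapper), and apply the one-hot frame stacking.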
def env_creater(config):
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    # name = config.get("name", "MiniGrid-Empty-8x8-v0")
    framestack = config.get("framestack", 4)
    env = gym.make(name)
    # env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
    env = minigrid.wrappers.ImgObsWrapper(env)
    env = OneHotWrapper(env,
                        config.vector_index if hasattr(config, "vector_index") else 0,
                        framestack=framestack
                        )
    return env
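

# Build a PRISM model of the grid with an external grid-to-PRISM converter,
# then use Storm (stormpy) to model-check the safety formula and extract a
# pre-safety shield together with the sparse model. Note: the converter path
# below is machine-specific.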
def create_shield(grid_file, prism_path):
    os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}")
    f = open(prism_path, "a")
    f.write("label \"AgentIsInLava\" = AgentIsInLava;")
    f.close()
    program = stormpy.parse_prism_program(prism_path)
    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"
    formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
    options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
    options.set_build_state_valuations(True)
    options.set_build_choice_labels(True)
    options.set_build_all_labels()
    model = stormpy.build_sparse_model_with_options(program, options)
    shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
    result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
    assert result.has_scheduler
    assert result.has_shield
    shield = result.shield
    stormpy.shields.export_shield(model, shield, "Grid.shield")
    return shield.construct(), model
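

# Dump the grid layout as text; this file is the "-i" input handed to the
# grid-to-PRISM converter in create_shield.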
def export_grid_to_text(env, grid_file):
    f = open(grid_file, "w")
    print(env)
    f.write(env.printGrid(init=True))
    # f.write(env.pprint_grid())
    f.close()
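

# Instantiate and reset the gym environment selected on the command line.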
def create_environment(args):
    env_id = args.env
    env = gym.make(env_id)
    env.reset()
    return env
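

# End-to-end pipeline: export the grid, synthesize the shield with Storm,
# print the allowed actions per model state, then train PPO on the wrapped
# environment for 30 iterations, checkpointing every 5 iterations.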
def main():
    args = parse_arguments(argparse)
    env = create_environment(args)
    ray.init(num_cpus=3)
    # print(env.pprint_grid())
    # print(env.printGrid(init=False))
    grid_file = args.grid_path
    export_grid_to_text(env, grid_file)
    prism_path = args.prism_path
    shield, model = create_shield(grid_file, prism_path)
    for state_id in model.states:
        choices = shield.get_choice(state_id)
        print(F"Allowed choices in state {state_id} are {choices.choice_map}")
    env_name = "mini-grid"
    register_env(env_name, env_creater)
    algo = (
        PPOConfig()
        .rollouts(num_rollout_workers=1)
        .resources(num_gpus=0)
        .environment(env="mini-grid")
        .framework("torch")
        .callbacks(MyCallbacks)
        .training(model={
            "fcnet_hiddens": [256, 256],
            "fcnet_activation": "relu",
        })
        .build()
    )
    episode_reward = 0
    terminated = truncated = False
    obs, info = env.reset()
    # while not terminated and not truncated:
    #     action = algo.compute_single_action(obs)
    #     obs, reward, terminated, truncated, info = env.step(action)
    for i in range(30):
        result = algo.train()
        print(pretty_print(result))
        if i % 5 == 0:
            checkpoint_dir = algo.save()
            print(f"Checkpoint saved in directory {checkpoint_dir}")
    ray.shutdown()


if __name__ == '__main__':
    main()