1 changed files with 267 additions and 0 deletions
			
			
		| @ -0,0 +1,267 @@ | |||||
|  | from typing import Dict, Optional, Union | ||||
|  | from ray.rllib.env.base_env import BaseEnv | ||||
|  | from ray.rllib.evaluation import RolloutWorker | ||||
|  | from ray.rllib.evaluation.episode import Episode | ||||
|  | from ray.rllib.evaluation.episode_v2 import EpisodeV2 | ||||
|  | from ray.rllib.policy import Policy | ||||
|  | from ray.rllib.utils.typing import PolicyID | ||||
|  | import stormpy | ||||
|  | import stormpy.core | ||||
|  | import stormpy.simulator | ||||
|  | 
 | ||||
|  | from collections import deque | ||||
|  | 
 | ||||
|  | import stormpy.shields | ||||
|  | import stormpy.logic | ||||
|  | 
 | ||||
|  | import stormpy.examples | ||||
|  | import stormpy.examples.files | ||||
|  | import os | ||||
|  | 
 | ||||
|  | import gymnasium as gym | ||||
|  | 
 | ||||
|  | import minigrid | ||||
|  | import numpy as np | ||||
|  | 
 | ||||
|  | import ray | ||||
|  | from ray.tune import register_env | ||||
|  | from ray.rllib.algorithms.ppo import PPOConfig | ||||
|  | from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator | ||||
|  | from ray import tune, air | ||||
|  | from ray.rllib.algorithms.callbacks import DefaultCallbacks | ||||
|  | from ray.tune.logger import pretty_print | ||||
|  | from ray.rllib.utils.numpy import one_hot | ||||
|  | from ray.rllib.algorithms import ppo | ||||
|  | 
 | ||||
|  | from ray.rllib.models.preprocessors import get_preprocessor | ||||
|  | 
 | ||||
|  | import matplotlib.pyplot as plt | ||||
|  | 
 | ||||
|  | import argparse | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | class MyCallbacks(DefaultCallbacks): | ||||
|  |     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: | ||||
|  |         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") | ||||
|  |         env = base_env.get_sub_environments()[0] | ||||
|  |         episode.user_data["count"] = 0 | ||||
|  |         # print(env.printGrid()) | ||||
|  |         # print(env.action_space.n) | ||||
|  |         # print(env.actions) | ||||
|  |         # print(env.mission) | ||||
|  |         # print(env.observation_space) | ||||
|  |         # img = env.get_frame() | ||||
|  |         # plt.imshow(img) | ||||
|  |         # plt.show() | ||||
|  |         | ||||
|  |     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: | ||||
|  |          episode.user_data["count"] = episode.user_data["count"] + 1 | ||||
|  |          env = base_env.get_sub_environments()[0] | ||||
|  |         #  print(env.env.env.printGrid()) | ||||
|  |      | ||||
|  |     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: | ||||
|  |         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") | ||||
|  |         env = base_env.get_sub_environments()[0] | ||||
|  |         # print(env.env.env.printGrid()) | ||||
|  |         # print(episode.user_data["count"]) | ||||
|  |          | ||||
|  |      | ||||
|  |         | ||||
|  | 
 | ||||
|  | class OneHotWrapper(gym.core.ObservationWrapper): | ||||
|  |     def __init__(self, env, vector_index, framestack): | ||||
|  |         super().__init__(env) | ||||
|  |         self.framestack = framestack | ||||
|  |         # 49=7x7 field of vision; 11=object types; 6=colors; 3=state types. | ||||
|  |         # +4: Direction. | ||||
|  |         self.single_frame_dim = 49 * (11 + 6 + 3) + 4 | ||||
|  |         self.init_x = None | ||||
|  |         self.init_y = None | ||||
|  |         self.x_positions = [] | ||||
|  |         self.y_positions = [] | ||||
|  |         self.x_y_delta_buffer = deque(maxlen=100) | ||||
|  |         self.vector_index = vector_index | ||||
|  |         self.frame_buffer = deque(maxlen=self.framestack) | ||||
|  |         for _ in range(self.framestack): | ||||
|  |             self.frame_buffer.append(np.zeros((self.single_frame_dim,))) | ||||
|  | 
 | ||||
|  |         self.observation_space = gym.spaces.Box( | ||||
|  |             0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32 | ||||
|  |         ) | ||||
|  | 
 | ||||
|  |     def observation(self, obs): | ||||
|  |         # Debug output: max-x/y positions to watch exploration progress. | ||||
|  |         if self.step_count == 0: | ||||
|  |             for _ in range(self.framestack): | ||||
|  |                 self.frame_buffer.append(np.zeros((self.single_frame_dim,))) | ||||
|  |             if self.vector_index == 0: | ||||
|  |                 if self.x_positions: | ||||
|  |                     max_diff = max( | ||||
|  |                         np.sqrt( | ||||
|  |                             (np.array(self.x_positions) - self.init_x) ** 2 | ||||
|  |                             + (np.array(self.y_positions) - self.init_y) ** 2 | ||||
|  |                         ) | ||||
|  |                     ) | ||||
|  |                     self.x_y_delta_buffer.append(max_diff) | ||||
|  |                     print( | ||||
|  |                         "100-average dist travelled={}".format( | ||||
|  |                             np.mean(self.x_y_delta_buffer) | ||||
|  |                         ) | ||||
|  |                     ) | ||||
|  |                     self.x_positions = [] | ||||
|  |                     self.y_positions = [] | ||||
|  |                 self.init_x = self.agent_pos[0] | ||||
|  |                 self.init_y = self.agent_pos[1] | ||||
|  | 
 | ||||
|  |        | ||||
|  |         self.x_positions.append(self.agent_pos[0]) | ||||
|  |         self.y_positions.append(self.agent_pos[1]) | ||||
|  | 
 | ||||
|  |         # One-hot the last dim into 11, 6, 3 one-hot vectors, then flatten. | ||||
|  |         objects = one_hot(obs[:, :, 0], depth=11) | ||||
|  |         colors = one_hot(obs[:, :, 1], depth=6) | ||||
|  |         states = one_hot(obs[:, :, 2], depth=3) | ||||
|  |        | ||||
|  |         all_ = np.concatenate([objects, colors, states], -1) | ||||
|  |         all_flat = np.reshape(all_, (-1,)) | ||||
|  |         direction = one_hot(np.array(self.agent_dir), depth=4).astype(np.float32) | ||||
|  |         single_frame = np.concatenate([all_flat, direction]) | ||||
|  |         self.frame_buffer.append(single_frame) | ||||
|  |         return np.concatenate(self.frame_buffer) | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def parse_arguments(argparse): | ||||
|  |     parser = argparse.ArgumentParser() | ||||
|  |     # parser.add_argument("--env", help="gym environment to load", default="MiniGrid-Empty-8x8-v0") | ||||
|  |     parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0") | ||||
|  |     parser.add_argument("--seed", type=int, help="seed for environment", default=1) | ||||
|  |     parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32) | ||||
|  |     parser.add_argument("--agent_view", default=False, action="store_true", help="draw the agent sees") | ||||
|  |     parser.add_argument("--grid_path", default="Grid.txt") | ||||
|  |     parser.add_argument("--prism_path", default="Grid.PRISM") | ||||
|  |      | ||||
|  |     args = parser.parse_args() | ||||
|  |      | ||||
|  |     return args | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def env_creater(config): | ||||
|  |     name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") | ||||
|  |     # name = config.get("name", "MiniGrid-Empty-8x8-v0") | ||||
|  |     framestack = config.get("framestack", 4) | ||||
|  |      | ||||
|  |     env = gym.make(name) | ||||
|  |     # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) | ||||
|  |     env = minigrid.wrappers.ImgObsWrapper(env) | ||||
|  |     env = OneHotWrapper(env, | ||||
|  |                         config.vector_index if hasattr(config, "vector_index") else 0, | ||||
|  |                         framestack=framestack | ||||
|  |                         ) | ||||
|  |        | ||||
|  | 
 | ||||
|  |     return env | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def create_shield(grid_file, prism_path): | ||||
|  |     os.system(F"/home/tknoll/Documents/main -v 'agent' -i {grid_file} -o {prism_path}") | ||||
|  |      | ||||
|  |     f = open(prism_path, "a") | ||||
|  |     f.write("label \"AgentIsInLava\" = AgentIsInLava;") | ||||
|  |     f.close() | ||||
|  |      | ||||
|  |      | ||||
|  |     program = stormpy.parse_prism_program(prism_path) | ||||
|  |     formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]" | ||||
|  |      | ||||
|  |     formulas = stormpy.parse_properties_for_prism_program(formula_str, program) | ||||
|  |     options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) | ||||
|  |     options.set_build_state_valuations(True) | ||||
|  |     options.set_build_choice_labels(True) | ||||
|  |     options.set_build_all_labels() | ||||
|  |     model = stormpy.build_sparse_model_with_options(program, options) | ||||
|  |      | ||||
|  |     shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)  | ||||
|  |     result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) | ||||
|  |      | ||||
|  |     assert result.has_scheduler | ||||
|  |     assert result.has_shield | ||||
|  |     shield = result.shield | ||||
|  | 
 | ||||
|  |     stormpy.shields.export_shield(model, shield,"Grid.shield") | ||||
|  |      | ||||
|  |     return shield.construct(), model | ||||
|  | 
 | ||||
|  | def export_grid_to_text(env, grid_file): | ||||
|  |     f = open(grid_file, "w") | ||||
|  |     print(env) | ||||
|  |     f.write(env.printGrid(init=True)) | ||||
|  |     # f.write(env.pprint_grid()) | ||||
|  |     f.close() | ||||
|  | 
 | ||||
|  | def create_environment(args): | ||||
|  |     env_id= args.env | ||||
|  |     env = gym.make(env_id) | ||||
|  |     env.reset() | ||||
|  |     return env | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def main(): | ||||
|  |     args = parse_arguments(argparse) | ||||
|  | 
 | ||||
|  |     env = create_environment(args) | ||||
|  |     ray.init(num_cpus=3) | ||||
|  | 
 | ||||
|  |     # print(env.pprint_grid()) | ||||
|  |     # print(env.printGrid(init=False)) | ||||
|  |      | ||||
|  |     grid_file = args.grid_path | ||||
|  |     export_grid_to_text(env, grid_file) | ||||
|  |      | ||||
|  |     prism_path = args.prism_path | ||||
|  |     shield, model = create_shield(grid_file, prism_path) | ||||
|  |      | ||||
|  |     for state_id in model.states: | ||||
|  |         choices = shield.get_choice(state_id) | ||||
|  |         print(F"Allowed choices in state {state_id}, are {choices.choice_map} ") | ||||
|  |          | ||||
|  |     env_name = "mini-grid" | ||||
|  |     register_env(env_name, env_creater) | ||||
|  |      | ||||
|  |    | ||||
|  |     algo =( | ||||
|  |         PPOConfig() | ||||
|  |         .rollouts(num_rollout_workers=1) | ||||
|  |         .resources(num_gpus=0) | ||||
|  |         .environment(env="mini-grid") | ||||
|  |         .framework("torch") | ||||
|  |         .callbacks(MyCallbacks) | ||||
|  |         .training(model={ | ||||
|  |             "fcnet_hiddens": [256,256], | ||||
|  |             "fcnet_activation": "relu", | ||||
|  |              | ||||
|  |         }) | ||||
|  |         .build() | ||||
|  |     ) | ||||
|  |     episode_reward = 0 | ||||
|  |     terminated = truncated = False | ||||
|  |     obs, info = env.reset() | ||||
|  |      | ||||
|  |     # while not terminated and not truncated: | ||||
|  |     #     action = algo.compute_single_action(obs) | ||||
|  |     #     obs, reward, terminated, truncated = env.step(action) | ||||
|  |      | ||||
|  |     for i in range(30): | ||||
|  |         result = algo.train() | ||||
|  |         print(pretty_print(result)) | ||||
|  | 
 | ||||
|  |         if i % 5 == 0: | ||||
|  |             checkpoint_dir = algo.save() | ||||
|  |             print(f"Checkpoint saved in directory {checkpoint_dir}") | ||||
|  | 
 | ||||
|  |     | ||||
|  | 
 | ||||
|  |     ray.shutdown() | ||||
|  | 
 | ||||
|  | if __name__ == '__main__': | ||||
|  |     main() | ||||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue