2 changed files with 6 additions and 278 deletions
			
			
		| @ -1,275 +0,0 @@ | |||||
| from typing import Dict, Optional, Union |  | ||||
| from ray.rllib.env.base_env import BaseEnv |  | ||||
| from ray.rllib.evaluation import RolloutWorker |  | ||||
| from ray.rllib.evaluation.episode import Episode |  | ||||
| from ray.rllib.evaluation.episode_v2 import EpisodeV2 |  | ||||
| from ray.rllib.policy import Policy |  | ||||
| from ray.rllib.utils.typing import PolicyID |  | ||||
| import stormpy |  | ||||
| import stormpy.core |  | ||||
| import stormpy.simulator |  | ||||
| 
 |  | ||||
| from collections import deque |  | ||||
| 
 |  | ||||
| import stormpy.shields |  | ||||
| import stormpy.logic |  | ||||
| 
 |  | ||||
| import stormpy.examples |  | ||||
| import stormpy.examples.files |  | ||||
| import os |  | ||||
| 
 |  | ||||
| import gymnasium as gym |  | ||||
| 
 |  | ||||
| import minigrid |  | ||||
| import numpy as np |  | ||||
| 
 |  | ||||
| import ray |  | ||||
| from ray.tune import register_env |  | ||||
| from ray.rllib.algorithms.ppo import PPOConfig |  | ||||
| from ray.rllib.utils.test_utils import check_learning_achieved, framework_iterator |  | ||||
| from ray import tune, air |  | ||||
| from ray.rllib.algorithms.callbacks import DefaultCallbacks |  | ||||
| from ray.tune.logger import pretty_print |  | ||||
| from ray.rllib.utils.numpy import one_hot |  | ||||
| from ray.rllib.algorithms import ppo |  | ||||
| 
 |  | ||||
| from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC |  | ||||
| from ray.rllib.models.tf.fcnet import FullyConnectedNetwork |  | ||||
| from ray.rllib.utils.torch_utils import FLOAT_MIN |  | ||||
| 
 |  | ||||
| from ray.rllib.models.preprocessors import get_preprocessor |  | ||||
| 
 |  | ||||
| import matplotlib.pyplot as plt |  | ||||
| 
 |  | ||||
| import argparse |  | ||||
| from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 |  | ||||
| from ray.rllib.utils.framework import try_import_tf, try_import_torch |  | ||||
| 
 |  | ||||
| from examples.shields.rl.Wrappers import OneHotShieldingWrapper |  | ||||
| 
 |  | ||||
| 
 |  | ||||
| torch, nn = try_import_torch() |  | ||||
| 
 |  | ||||
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """Torch model that wraps a fully connected net and masks its logits.

    The (original) observation space must be a ``gymnasium.spaces.Dict``
    with the keys ``"observations"`` and ``"action_mask"``. Only the
    ``"observations"`` part is fed to the wrapped network; the
    ``"action_mask"`` part is turned into an additive penalty on the logits.
    """

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        """Build the wrapped fully connected network.

        Args:
            obs_space: Observation space; its ``original_space`` attribute is
                used when present, otherwise the space itself.
            action_space: Action space, forwarded to the wrapped network.
            num_outputs: Number of logits, forwarded to the wrapped network.
            model_config: Model config dict. ``custom_model_config`` may
                contain ``no_masking`` to disable the masking step.
            name: Model name; the wrapped network gets ``name + "_internal"``.
            **kwargs: Forwarded to ``TorchModelV2.__init__``.

        Raises:
            AssertionError: If the observation space is not a Dict space with
                both required keys.
        """
        orig_space = getattr(obs_space, "original_space", obs_space)
        # The module-level name ``Dict`` is ``typing.Dict``, whose isinstance
        # check only accepts ``dict`` subclasses. A gymnasium Dict space is not
        # one, so the space class has to be referenced explicitly here.
        assert (
            isinstance(orig_space, gym.spaces.Dict)
            and "action_mask" in orig_space.spaces
            and "observations" in orig_space.spaces
        ), "obs space must be a Dict space with 'action_mask' and 'observations'"

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )
        nn.Module.__init__(self)

        self.internal_model = TorchFC(
            orig_space["observations"],
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        # Disabling action masking will likely lead to invalid actions.
        self.no_masking = model_config["custom_model_config"].get("no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        """Return (possibly masked) logits and the unchanged ``state``.

        ``seq_lens`` is accepted for interface compatibility and not used.
        """
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]})

        # If action masking is disabled, directly return unmasked logits.
        if self.no_masking:
            return logits, state

        # log(1) = 0 leaves a logit untouched; log(0) = -inf is clamped to
        # FLOAT_MIN so the sum stays finite.
        # NOTE(review): assumes the mask holds only 0/1 entries - confirm
        # against the wrapper that produces it.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        """Delegate the value estimate to the wrapped network."""
        return self.internal_model.value_function()
| 
 |  | ||||
| 
 |  | ||||
| 
 |  | ||||
class MyCallbacks(DefaultCallbacks):
    """Callbacks that count the steps of every episode.

    The running count is stored in ``episode.user_data["count"]``.
    """

    def on_episode_start(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Reset the per-episode step counter to zero."""
        episode.user_data["count"] = 0

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy] | None = None,
        episode: Episode | EpisodeV2,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Increment the per-episode step counter by one."""
        episode.user_data["count"] += 1

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[PolicyID, Policy],
        episode: Episode | EpisodeV2 | Exception,
        env_index: int | None = None,
        **kwargs,
    ) -> None:
        """Do nothing; kept as a hook for end-of-episode debugging output."""
|          |  | ||||
|      |  | ||||
|      |  | ||||
| 
 |  | ||||
| 
 |  | ||||
def parse_arguments(argparse, argv=None):
    """Parse the command-line options of the training script.

    Args:
        argparse: The ``argparse`` module (or a compatible object exposing
            ``ArgumentParser``). The parameter name shadows the module on
            purpose to keep the existing call signature.
        argv: Optional list of argument strings. ``None`` (the default) makes
            the parser read ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        The parsed namespace with the attributes ``env``, ``seed``,
        ``tile_size``, ``agent_view``, ``grid_path`` and ``prism_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", help="gym environment to load", default="MiniGrid-LavaCrossingS9N1-v0")
    parser.add_argument("--seed", type=int, help="seed for environment", default=1)
    parser.add_argument("--tile_size", type=int, help="size at which to render tiles", default=32)
    parser.add_argument("--agent_view", default=False, action="store_true", help="draw what the agent sees")
    parser.add_argument("--grid_path", default="Grid.txt")
    parser.add_argument("--prism_path", default="Grid.PRISM")

    return parser.parse_args(argv)
| 
 |  | ||||
| 
 |  | ||||
def env_creater(config):
    """Environment factory registered with Ray Tune.

    Reads ``"name"`` (default ``"MiniGrid-LavaCrossingS9N1-v0"``) and
    ``"framestack"`` (default ``4``) from ``config``, builds the environment,
    reduces its observation with ``ImgObsWrapper`` and wraps the result in
    ``OneHotShieldingWrapper``.

    Args:
        config: Mapping-like environment config. Its ``vector_index``
            attribute is passed to the shielding wrapper when present,
            otherwise ``0`` is used.

    Returns:
        The wrapped environment.
    """
    env_id = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    stack_depth = config.get("framestack", 4)
    # Plain dict configs have no ``vector_index`` attribute; fall back to 0.
    vector_slot = getattr(config, "vector_index", 0)

    image_env = minigrid.wrappers.ImgObsWrapper(gym.make(env_id))
    return OneHotShieldingWrapper(image_env, vector_slot, framestack=stack_depth)
| 
 |  | ||||
| 
 |  | ||||
def create_shield(grid_file, prism_path):
    """Translate a grid file to PRISM, model check it and build a shield.

    Steps: run the external converter on ``grid_file`` to produce
    ``prism_path``, append an ``AgentIsInLava`` label to that file, build the
    sparse model with stormpy, model check the safety property with a
    pre-safety shield expression, and export the shield to ``Grid.shield``.

    Args:
        grid_file: Path of the textual grid description to convert.
        prism_path: Path the converter writes the PRISM program to.

    Returns:
        A tuple ``(constructed_shield, model)``.

    Raises:
        subprocess.CalledProcessError: If the converter exits with a non-zero
            status (previously ignored, which let later steps run on a
            missing or stale PRISM file).
        AssertionError: If model checking yields no scheduler or no shield.
    """
    # Local import keeps this block self-contained; the file-level imports
    # do not include subprocess.
    import subprocess

    # NOTE(review): the converter path is machine-specific - consider making
    # it configurable.
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters intact.
    subprocess.run(
        [
            "/home/tknoll/Documents/main",
            "-v", "agent",
            "-i", str(grid_file),
            "-o", str(prism_path),
        ],
        check=True,
    )

    with open(prism_path, "a") as prism_file:
        prism_file.write("label \"AgentIsInLava\" = AgentIsInLava;")

    program = stormpy.parse_prism_program(prism_path)
    # NOTE(review): the property refers to "AgentIsInLavaAndNotDone" while the
    # label appended above is "AgentIsInLava"; presumably the converter emits
    # the former - confirm.
    formula_str = "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"

    formulas = stormpy.parse_properties_for_prism_program(formula_str, program)
    options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
    options.set_build_state_valuations(True)
    options.set_build_choice_labels(True)
    options.set_build_all_labels()
    model = stormpy.build_sparse_model_with_options(program, options)

    shield_specification = stormpy.logic.ShieldExpression(
        stormpy.logic.ShieldingType.PRE_SAFETY,
        stormpy.logic.ShieldComparison.RELATIVE,
        0.1,
    )
    result = stormpy.model_checking(
        model,
        formulas[0],
        extract_scheduler=True,
        shield_expression=shield_specification,
    )

    assert result.has_scheduler, "model checking returned no scheduler"
    assert result.has_shield, "model checking returned no shield"
    shield = result.shield

    stormpy.shields.export_shield(model, shield, "Grid.shield")

    return shield.construct(), model
| 
 |  | ||||
def export_grid_to_text(env, grid_file):
    """Write the textual grid representation of ``env`` to ``grid_file``.

    Args:
        env: Environment exposing ``printGrid(init=True)``, which must return
            the string to write.
        grid_file: Destination path; an existing file is overwritten.
    """
    # The context manager closes the file even if printGrid() raises.
    with open(grid_file, "w") as grid_out:
        print(env)
        grid_out.write(env.printGrid(init=True))
| 
 |  | ||||
def create_environment(args):
    """Create the environment named by ``args.env`` and reset it once.

    Args:
        args: Parsed command-line namespace; only ``args.env`` is read.

    Returns:
        The freshly reset environment.
    """
    environment = gym.make(args.env)
    environment.reset()
    return environment
| 
 |  | ||||
| 
 |  | ||||
def main():
    """Build the shield for the chosen grid and train PPO for 30 iterations.

    The grid of the environment is exported to text, converted to a shield,
    and the allowed choices per model state are printed. A PPO algorithm is
    then trained on the registered ``"mini-grid"`` environment, and a
    checkpoint is saved on every fifth iteration (starting with iteration 0).
    """
    args = parse_arguments(argparse)

    env = create_environment(args)
    ray.init(num_cpus=3)

    # Shut Ray down even when shield creation or training fails.
    try:
        grid_file = args.grid_path
        export_grid_to_text(env, grid_file)

        prism_path = args.prism_path
        shield, model = create_shield(grid_file, prism_path)

        for state_id in model.states:
            choices = shield.get_choice(state_id)
            print(F"Allowed choices in state {state_id}, are {choices.choice_map} ")

        env_name = "mini-grid"
        register_env(env_name, env_creater)

        algo = (
            PPOConfig()
            .rollouts(num_rollout_workers=1)
            .resources(num_gpus=0)
            .environment(env=env_name)
            .framework("torch")
            .callbacks(MyCallbacks)
            .training(model={
                "fcnet_hiddens": [256, 256],
                "fcnet_activation": "relu",
            })
            .build()
        )

        for i in range(30):
            result = algo.train()
            print(pretty_print(result))

            if i % 5 == 0:
                checkpoint_dir = algo.save()
                print(f"Checkpoint saved in directory {checkpoint_dir}")
    finally:
        ray.shutdown()
| 
 |  | ||||
# Start training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue