From f3747a1479080288d844a2239014bd450a9025cb Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Wed, 6 Sep 2023 10:18:55 +0200 Subject: [PATCH] renaming / shield handling changes --- examples/shields/rl/11_minigridrl.py | 112 +++++++++--------- examples/shields/rl/12_basic_training.py | 4 +- examples/shields/rl/13_minigridsb.py | 98 ++------------- examples/shields/rl/ShieldHandlers.py | 81 +++++++++++++ ...{MaskModels.py => TorchActionMaskModel.py} | 3 - .../shields/rl/{Wrapper.py => Wrappers.py} | 97 +++++++++++++-- examples/shields/rl/helpers.py | 77 ++---------- 7 files changed, 249 insertions(+), 223 deletions(-) create mode 100644 examples/shields/rl/ShieldHandlers.py rename examples/shields/rl/{MaskModels.py => TorchActionMaskModel.py} (94%) rename examples/shields/rl/{Wrapper.py => Wrappers.py} (61%) diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 4162189..4ce06a7 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -1,76 +1,78 @@ -from typing import Dict -from ray.rllib.env.base_env import BaseEnv -from ray.rllib.evaluation import RolloutWorker -from ray.rllib.evaluation.episode import Episode -from ray.rllib.evaluation.episode_v2 import EpisodeV2 -from ray.rllib.policy import Policy -from ray.rllib.utils.typing import PolicyID -from ray.rllib.algorithms.algorithm import Algorithm +# from typing import Dict +# from ray.rllib.env.base_env import BaseEnv +# from ray.rllib.evaluation import RolloutWorker +# from ray.rllib.evaluation.episode import Episode +# from ray.rllib.evaluation.episode_v2 import EpisodeV2 +# from ray.rllib.policy import Policy +# from ray.rllib.utils.typing import PolicyID import gymnasium as gym import minigrid -import numpy as np +# import numpy as np -import ray +# import ray from ray.tune import register_env from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.dqn.dqn import DQNConfig -from ray.rllib.algorithms.callbacks import DefaultCallbacks +# from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.tune.logger import pretty_print from ray.rllib.models import ModelCatalog -from ray.rllib.utils.torch_utils import FLOAT_MIN -from MaskModels import TorchActionMaskModel -from Wrapper import OneHotWrapper, MiniGridEnvWrapper +from TorchActionMaskModel import TorchActionMaskModel +from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from helpers import parse_arguments, create_log_dir +from ShieldHandlers import MiniGridShieldHandler import matplotlib.pyplot as plt -class MyCallbacks(DefaultCallbacks): - def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: - # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") - env = base_env.get_sub_environments()[0] - episode.user_data["count"] = 0 - # print("On episode start print") - # print(env.printGrid()) - # print(worker) - # print(env.action_space.n) - # print(env.actions) - # print(env.mission) - # print(env.observation_space) - # img = env.get_frame() - # plt.imshow(img) - # plt.show() +from ray.tune.logger import TBXLogger + +# class MyCallbacks(DefaultCallbacks): +# def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: +# # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") +# env = base_env.get_sub_environments()[0] +# episode.user_data["count"] = 0 +# # print("On episode start print") +# # print(env.printGrid()) +# # print(worker) +# # print(env.action_space.n) +# # print(env.actions) +# # print(env.mission) +# # print(env.observation_space) +# # img = env.get_frame() +# # plt.imshow(img) +# # plt.show() - def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: - episode.user_data["count"] = episode.user_data["count"] + 1 - env = base_env.get_sub_environments()[0] - # print(env.printGrid()) +# def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: +# episode.user_data["count"] = episode.user_data["count"] + 1 +# env = base_env.get_sub_environments()[0] +# # print(env.printGrid()) - def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: - # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") - env = base_env.get_sub_environments()[0] - #print("On episode end print") - #print(env.printGrid()) +# def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: +# # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") +# env = base_env.get_sub_environments()[0] +# #print("On episode end print") +# #print(env.printGrid()) -def env_creater_custom(config): - framestack = config.get("framestack", 4) +def shielding_env_creater(config): name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") framestack = config.get("framestack", 4) args = config.get("args", None) args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" + shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) + env = gym.make(name) - env = MiniGridEnvWrapper(env, args=args) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) - env = OneHotWrapper(env, + env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, "vector_index") else 0, framestack=framestack ) @@ -80,32 +82,32 @@ def env_creater_custom(config): -def register_custom_minigrid_env(args): +def register_minigrid_shielding_env(args): env_name = "mini-grid" - register_env(env_name, env_creater_custom) + register_env(env_name, shielding_env_creater) ModelCatalog.register_custom_model( - "pa_model", + "shielding_model", TorchActionMaskModel ) def ppo(args): - register_custom_minigrid_env(args) + register_minigrid_shielding_env(args) config = (PPOConfig() .rollouts(num_rollout_workers=args.workers) .resources(num_gpus=0) .environment(env="mini-grid", env_config={"name": args.env, "args": args}) - .framework("torch") - .callbacks(MyCallbacks) + .framework("torch") + #.callbacks(MyCallbacks) .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ - "type": "ray.tune.logger.TBXLogger", + "type": TBXLogger, "logdir": create_log_dir(args) }) .training(_enable_learner_api=False ,model={ - "custom_model": "pa_model", + "custom_model": "shielding_model", "custom_model_config" : {"no_masking": args.no_masking} })) @@ -114,6 +116,8 @@ def ppo(args): config.build() ) + algo.eva + for i in range(args.iterations): result = algo.train() print(pretty_print(result)) @@ -124,7 +128,7 @@ def ppo(args): def dqn(args): - register_custom_minigrid_env(args) + register_minigrid_shielding_env(args) config = DQNConfig() @@ -132,14 +136,14 @@ def dqn(args): config = config.rollouts(num_rollout_workers=args.workers) config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args }) config = config.framework("torch") - config = config.callbacks(MyCallbacks) + #config = config.callbacks(MyCallbacks) config = config.rl_module(_enable_rl_module_api = False) config = config.debugging(logger_config={ - "type": "ray.tune.logger.TBXLogger", + "type": TBXLogger, "logdir": create_log_dir(args) }) config = config.training(hiddens=[], dueling=False, model={ - "custom_model": "pa_model", + "custom_model": "shielding_model", "custom_model_config" : {"no_masking": args.no_masking} }) diff --git a/examples/shields/rl/12_basic_training.py b/examples/shields/rl/12_basic_training.py index 72bb868..40759b2 100644 --- a/examples/shields/rl/12_basic_training.py +++ b/examples/shields/rl/12_basic_training.py @@ -45,7 +45,7 @@ import argparse from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.utils.framework import try_import_tf, try_import_torch -from Wrapper import OneHotWrapper +from examples.shields.rl.Wrappers import OneHotShieldingWrapper torch, nn = try_import_torch() @@ -162,7 +162,7 @@ def env_creater(config): env = gym.make(name) # env = minigrid.wrappers.RGBImgPartialObsWrapper(env) env = minigrid.wrappers.ImgObsWrapper(env) - env = OneHotWrapper(env, + env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, "vector_index") else 0, framestack=framestack ) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 7959319..8d0c86f 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -1,19 +1,19 @@ from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.evaluation import evaluate_policy -from sb3_contrib.common.maskable.utils import get_action_masks from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.wrappers import ActionMasker from stable_baselines3.common.callbacks import BaseCallback import gymnasium as gym -from gymnasium.spaces import Dict, Box from minigrid.core.actions import Actions import numpy as np import time -from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir +from helpers import parse_arguments, extract_keys, get_action_index_mapping, create_log_dir +from ShieldHandlers import MiniGridShieldHandler +from Wrappers import MiniGridSbShieldingWrapper class CustomCallback(BaseCallback): def __init__(self, verbose: int = 0, env=None): @@ -22,92 +22,10 @@ class CustomCallback(BaseCallback): def _on_step(self) -> bool: - #print(self.env.printGrid()) + print(self.env.printGrid()) return super()._on_step() -class MiniGridEnvWrapper(gym.core.Wrapper): - def __init__(self, env, args=None, no_masking=False): - super(MiniGridEnvWrapper, self).__init__(env) - self.max_available_actions = env.action_space.n - self.observation_space = env.observation_space.spaces["image"] - - self.args = args - self.no_masking = no_masking - - def create_action_mask(self): - if self.no_masking: - return np.array([1.0] * self.max_available_actions, dtype=np.int8) - - - coordinates = self.env.agent_pos - view_direction = self.env.agent_dir - - key_text = "" - - # only support one key for now - if self.keys: - key_text = F"!Agent_has_{self.keys[0]}_key\t& " - - - if self.env.carrying and self.env.carrying.type == "key": - key_text = F"Agent_has_{self.env.carrying.color}_key\t& " - - #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") - cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" - - allowed_actions = [] - - - # Create the mask - # If shield restricts action mask only valid with 1.0 - # else set all actions as valid - mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) - - if cur_pos_str in self.shield and self.shield[cur_pos_str]: - allowed_actions = self.shield[cur_pos_str] - for allowed_action in allowed_actions: - index = get_action_index_mapping(allowed_action[1]) - if index is None: - assert(False) - mask[index] = 1.0 - else: - # print(F"Not in shield {cur_pos_str}") - for index, x in enumerate(mask): - mask[index] = 1.0 - - front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) - - if front_tile is not None and front_tile.type == "key": - mask[Actions.pickup] = 1.0 - - if self.env.carrying: - mask[Actions.drop] = 1.0 - - if front_tile and front_tile.type == "door": - mask[Actions.toggle] = 1.0 - - - - return mask - - def reset(self, *, seed=None, options=None): - obs, infos = self.env.reset(seed=seed, options=options) - - keys = extract_keys(self.env) - shield = create_shield_dict(self.env, self.args) - - self.keys = keys - self.shield = shield - return obs["image"], infos - - def step(self, action): - orig_obs, rew, done, truncated, info = self.env.step(action) - obs = orig_obs["image"] - - return obs, rew, done, truncated, info - - def mask_fn(env: gym.Env): return env.create_action_mask() @@ -118,9 +36,13 @@ def main(): import argparse args = parse_arguments(argparse) + args.grid_path = F"{args.grid_path}.txt" + args.prism_path = F"{args.prism_path}.prism" + + shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) env = gym.make(args.env, render_mode="rgb_array") - env = MiniGridEnvWrapper(env,args=args, no_masking=args.no_masking) + env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, no_masking=args.no_masking) env = ActionMasker(env, mask_fn) callback = CustomCallback(1, env) model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args)) @@ -132,7 +54,7 @@ def main(): model.learn(iterations, callback=callback) - mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10) + #W mean_reward, std_reward = evaluate_policy(model, model.get_env()) vec_env = model.get_env() obs = vec_env.reset() diff --git a/examples/shields/rl/ShieldHandlers.py b/examples/shields/rl/ShieldHandlers.py new file mode 100644 index 0000000..0ec7936 --- /dev/null +++ b/examples/shields/rl/ShieldHandlers.py @@ -0,0 +1,81 @@ +import stormpy +import stormpy.core +import stormpy.simulator + +import stormpy.shields +import stormpy.logic + +import stormpy.examples +import stormpy.examples.files + +from abc import ABC + +import os + +class ShieldHandler(ABC): + def __init__(self) -> None: + pass + def create_shield(self, **kwargs): + pass + +class MiniGridShieldHandler(ShieldHandler): + def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None: + self.grid_file = grid_file + self.grid_to_prism_path = grid_to_prism_path + self.prism_path = prism_path + self.formula = formula + + def __export_grid_to_text(self, env): + f = open(self.grid_file, "w") + f.write(env.printGrid(init=True)) + f.close() + + + def __create_prism(self): + os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}") + + f = open(self.prism_path, "a") + f.write("label \"AgentIsInLava\" = AgentIsInLava;") + f.close() + + def __create_shield_dict(self): + program = stormpy.parse_prism_program(self.prism_path) + shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) + + formulas = stormpy.parse_properties_for_prism_program(self.formula, program) + options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) + options.set_build_state_valuations(True) + options.set_build_choice_labels(True) + options.set_build_all_labels() + model = stormpy.build_sparse_model_with_options(program, options) + + result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) + + assert result.has_scheduler + assert result.has_shield + shield = result.shield + + action_dictionary = {} + shield_scheduler = shield.construct() + + for stateID in model.states: + choice = shield_scheduler.get_choice(stateID) + choices = choice.choice_map + state_valuation = model.state_valuations.get_string(stateID) + + actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] + + action_dictionary[state_valuation] = actions_to_be_executed + + stormpy.shields.export_shield(model, shield, "Grid.shield") + + return action_dictionary + + + def create_shield(self, **kwargs): + env = kwargs["env"] + self.__export_grid_to_text(env) + self.__create_prism() + + return self.__create_shield_dict() + \ No newline at end of file diff --git a/examples/shields/rl/MaskModels.py b/examples/shields/rl/TorchActionMaskModel.py similarity index 94% rename from examples/shields/rl/MaskModels.py rename to examples/shields/rl/TorchActionMaskModel.py index e882a51..42b6805 100644 --- a/examples/shields/rl/MaskModels.py +++ b/examples/shields/rl/TorchActionMaskModel.py @@ -11,7 +11,6 @@ torch, nn = try_import_torch() class TorchActionMaskModel(TorchModelV2, nn.Module): - """PyTorch version of above ActionMaskingModel.""" def __init__( self, @@ -23,7 +22,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): **kwargs, ): orig_space = getattr(obs_space, "original_space", obs_space) - custom_config = model_config['custom_model_config'] TorchModelV2.__init__( self, obs_space, action_space, num_outputs, model_config, name, **kwargs @@ -58,7 +56,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module): inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) masked_logits = logits + inf_mask - # Return masked logits. return masked_logits, state diff --git a/examples/shields/rl/Wrapper.py b/examples/shields/rl/Wrappers.py similarity index 61% rename from examples/shields/rl/Wrapper.py rename to examples/shields/rl/Wrappers.py index 390974b..ef761fa 100644 --- a/examples/shields/rl/Wrapper.py +++ b/examples/shields/rl/Wrappers.py @@ -7,10 +7,11 @@ from gymnasium.spaces import Dict, Box from collections import deque from ray.rllib.utils.numpy import one_hot -from helpers import get_action_index_mapping, create_shield_dict, extract_keys +from helpers import get_action_index_mapping, extract_keys +from ShieldHandlers import ShieldHandler -class OneHotWrapper(gym.core.ObservationWrapper): +class OneHotShieldingWrapper(gym.core.ObservationWrapper): def __init__(self, env, vector_index, framestack): super().__init__(env) self.framestack = framestack @@ -80,9 +81,9 @@ class OneHotWrapper(gym.core.ObservationWrapper): return tmp -class MiniGridEnvWrapper(gym.core.Wrapper): - def __init__(self, env, args=None): - super(MiniGridEnvWrapper, self).__init__(env) +class MiniGridShieldingWrapper(gym.core.Wrapper): + def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True): + super(MiniGridShieldingWrapper, self).__init__(env) self.max_available_actions = env.action_space.n self.observation_space = Dict( { @@ -90,7 +91,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper): "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), } ) - self.args = args + self.shield_creator = shield_creator + self.create_shield_at_reset = create_shield_at_reset + self.shield = shield_creator.create_shield(env=self.env) def create_action_mask(self): coordinates = self.env.agent_pos @@ -142,7 +145,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper): def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) - self.shield = create_shield_dict(self.env, self.args) + + if self.create_shield_at_reset: + self.shield = self.shield_creator.create_shield(env=self.env) + self.keys = extract_keys(self.env) mask = self.create_action_mask() return { @@ -163,3 +169,80 @@ class MiniGridEnvWrapper(gym.core.Wrapper): return obs, rew, done, truncated, info + +class MiniGridSbShieldingWrapper(gym.core.Wrapper): + def __init__(self, env, shield_creator : ShieldHandler, no_masking=False): + super(MiniGridSbShieldingWrapper, self).__init__(env) + self.max_available_actions = env.action_space.n + self.observation_space = env.observation_space.spaces["image"] + + self.shield_creator = shield_creator + self.no_masking = no_masking + + def create_action_mask(self): + if self.no_masking: + return np.array([1.0] * self.max_available_actions, dtype=np.int8) + + coordinates = self.env.agent_pos + view_direction = self.env.agent_dir + + key_text = "" + + # only support one key for now + if self.keys: + key_text = F"!Agent_has_{self.keys[0]}_key\t& " + + + if self.env.carrying and self.env.carrying.type == "key": + key_text = F"Agent_has_{self.env.carrying.color}_key\t& " + + #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") + cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" + allowed_actions = [] + + # Create the mask + # If shield restricts action mask only valid with 1.0 + # else set all actions as valid + mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) + + if cur_pos_str in self.shield and self.shield[cur_pos_str]: + allowed_actions = self.shield[cur_pos_str] + for allowed_action in allowed_actions: + index = get_action_index_mapping(allowed_action[1]) + if index is None: + assert(False) + mask[index] = 1.0 + else: + # print(F"Not in shield {cur_pos_str}") + for index, x in enumerate(mask): + mask[index] = 1.0 + + front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) + + # if front_tile is not None and front_tile.type == "key": + # mask[Actions.pickup] = 1.0 + + # if self.env.carrying: + # mask[Actions.drop] = 1.0 + + if front_tile and front_tile.type == "door": + mask[Actions.toggle] = 1.0 + + return mask + + def reset(self, *, seed=None, options=None): + obs, infos = self.env.reset(seed=seed, options=options) + + keys = extract_keys(self.env) + shield = self.shield_creator.create_shield(env=self.env) + + self.keys = keys + self.shield = shield + return obs["image"], infos + + def step(self, action): + orig_obs, rew, done, truncated, info = self.env.step(action) + obs = orig_obs["image"] + + return obs, rew, done, truncated, info + diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 9106745..59af016 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -1,6 +1,5 @@ import minigrid from minigrid.core.actions import Actions -import gymnasium as gym from datetime import datetime @@ -14,7 +13,6 @@ import stormpy.logic import stormpy.examples import stormpy.examples.files -import os def extract_keys(env): @@ -66,17 +64,17 @@ def parse_arguments(argparse): choices=[ "MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N3-v0", - "MiniGrid-DoorKey-8x8-v0", - "MiniGrid-LockedRoom-v0", - "MiniGrid-FourRooms-v0", - "MiniGrid-LavaGapS7-v0", - "MiniGrid-SimpleCrossingS9N3-v0", - "MiniGrid-DoorKey-16x16-v0", - "MiniGrid-Empty-Random-6x6-v0", + # "MiniGrid-DoorKey-8x8-v0", + # "MiniGrid-LockedRoom-v0", + # "MiniGrid-FourRooms-v0", + # "MiniGrid-LavaGapS7-v0", + # "MiniGrid-SimpleCrossingS9N3-v0", + # "MiniGrid-DoorKey-16x16-v0", + # "MiniGrid-Empty-Random-6x6-v0", ]) # parser.add_argument("--seed", type=int, help="seed for environment", default=None) - parser.add_argument("--grid_to_prism_path", default="./main") + parser.add_argument("--grid_to_prism_binary_path", default="./main") parser.add_argument("--grid_path", default="grid") parser.add_argument("--prism_path", default="grid") parser.add_argument("--no_masking", default=False) @@ -90,62 +88,3 @@ def parse_arguments(argparse): args = parser.parse_args() return args - - -def export_grid_to_text(env, grid_file): - f = open(grid_file, "w") - # print(env) - f.write(env.printGrid(init=True)) - f.close() - - -def create_shield(grid_to_prism_path, grid_file, prism_path, formula): - os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}") - - f = open(prism_path, "a") - f.write("label \"AgentIsInLava\" = AgentIsInLava;") - f.close() - - program = stormpy.parse_prism_program(prism_path) - - shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1) - - formulas = stormpy.parse_properties_for_prism_program(formula, program) - options = stormpy.BuilderOptions([p.raw_formula for p in formulas]) - options.set_build_state_valuations(True) - options.set_build_choice_labels(True) - options.set_build_all_labels() - model = stormpy.build_sparse_model_with_options(program, options) - - result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification) - - assert result.has_scheduler - assert result.has_shield - shield = result.shield - - action_dictionary = {} - shield_scheduler = shield.construct() - - for stateID in model.states: - choice = shield_scheduler.get_choice(stateID) - choices = choice.choice_map - state_valuation = model.state_valuations.get_string(stateID) - - actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices] - - action_dictionary[state_valuation] = actions_to_be_executed - - stormpy.shields.export_shield(model, shield, "Grid.shield") - return action_dictionary - - -def create_shield_dict(env, args): - grid_file = args.grid_path - grid_to_prism_path = args.grid_to_prism_path - export_grid_to_text(env, grid_file) - - prism_path = args.prism_path - shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path, args.formula) - - return shield_dict -