Browse Source

renaming / shield handling changes

refactoring
Thomas Knoll 1 year ago
parent
commit
f3747a1479
  1. 112
      examples/shields/rl/11_minigridrl.py
  2. 4
      examples/shields/rl/12_basic_training.py
  3. 98
      examples/shields/rl/13_minigridsb.py
  4. 81
      examples/shields/rl/ShieldHandlers.py
  5. 3
      examples/shields/rl/TorchActionMaskModel.py
  6. 97
      examples/shields/rl/Wrappers.py
  7. 77
      examples/shields/rl/helpers.py

112
examples/shields/rl/11_minigridrl.py

@ -1,76 +1,78 @@
from typing import Dict
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from ray.rllib.algorithms.algorithm import Algorithm
# from typing import Dict
# from ray.rllib.env.base_env import BaseEnv
# from ray.rllib.evaluation import RolloutWorker
# from ray.rllib.evaluation.episode import Episode
# from ray.rllib.evaluation.episode_v2 import EpisodeV2
# from ray.rllib.policy import Policy
# from ray.rllib.utils.typing import PolicyID
import gymnasium as gym
import minigrid
import numpy as np
# import numpy as np
import ray
# import ray
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
# from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from TorchActionMaskModel import TorchActionMaskModel
from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir
from ShieldHandlers import MiniGridShieldHandler
import matplotlib.pyplot as plt
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
# print("On episode start print")
# print(env.printGrid())
# print(worker)
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
from ray.tune.logger import TBXLogger
# class MyCallbacks(DefaultCallbacks):
# def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
# env = base_env.get_sub_environments()[0]
# episode.user_data["count"] = 0
# # print("On episode start print")
# # print(env.printGrid())
# # print(worker)
# # print(env.action_space.n)
# # print(env.actions)
# # print(env.mission)
# # print(env.observation_space)
# # img = env.get_frame()
# # plt.imshow(img)
# # plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
# def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# episode.user_data["count"] = episode.user_data["count"] + 1
# env = base_env.get_sub_environments()[0]
# # print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
#print("On episode end print")
#print(env.printGrid())
# def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
# env = base_env.get_sub_environments()[0]
# #print("On episode end print")
# #print(env.printGrid())
def env_creater_custom(config):
framestack = config.get("framestack", 4)
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridEnvWrapper(env, args=args)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotWrapper(env,
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
@ -80,32 +82,32 @@ def env_creater_custom(config):
def register_custom_minigrid_env(args):
def register_minigrid_shielding_env(args):
env_name = "mini-grid"
register_env(env_name, env_creater_custom)
register_env(env_name, shielding_env_creater)
ModelCatalog.register_custom_model(
"pa_model",
"shielding_model",
TorchActionMaskModel
)
def ppo(args):
register_custom_minigrid_env(args)
register_minigrid_shielding_env(args)
config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.environment(env="mini-grid", env_config={"name": args.env, "args": args})
.framework("torch")
.callbacks(MyCallbacks)
.framework("torch")
#.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False)
.debugging(logger_config={
"type": "ray.tune.logger.TBXLogger",
"type": TBXLogger,
"logdir": create_log_dir(args)
})
.training(_enable_learner_api=False ,model={
"custom_model": "pa_model",
"custom_model": "shielding_model",
"custom_model_config" : {"no_masking": args.no_masking}
}))
@ -114,6 +116,8 @@ def ppo(args):
config.build()
)
algo.eva
for i in range(args.iterations):
result = algo.train()
print(pretty_print(result))
@ -124,7 +128,7 @@ def ppo(args):
def dqn(args):
register_custom_minigrid_env(args)
register_minigrid_shielding_env(args)
config = DQNConfig()
@ -132,14 +136,14 @@ def dqn(args):
config = config.rollouts(num_rollout_workers=args.workers)
config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
config = config.framework("torch")
config = config.callbacks(MyCallbacks)
#config = config.callbacks(MyCallbacks)
config = config.rl_module(_enable_rl_module_api = False)
config = config.debugging(logger_config={
"type": "ray.tune.logger.TBXLogger",
"type": TBXLogger,
"logdir": create_log_dir(args)
})
config = config.training(hiddens=[], dueling=False, model={
"custom_model": "pa_model",
"custom_model": "shielding_model",
"custom_model_config" : {"no_masking": args.no_masking}
})

4
examples/shields/rl/12_basic_training.py

@ -45,7 +45,7 @@ import argparse
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from Wrapper import OneHotWrapper
from examples.shields.rl.Wrappers import OneHotShieldingWrapper
torch, nn = try_import_torch()
@ -162,7 +162,7 @@ def env_creater(config):
env = gym.make(name)
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
env = minigrid.wrappers.ImgObsWrapper(env)
env = OneHotWrapper(env,
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)

98
examples/shields/rl/13_minigridsb.py

@ -1,19 +1,19 @@
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
from gymnasium.spaces import Dict, Box
from minigrid.core.actions import Actions
import numpy as np
import time
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
from helpers import parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
from ShieldHandlers import MiniGridShieldHandler
from Wrappers import MiniGridSbShieldingWrapper
class CustomCallback(BaseCallback):
def __init__(self, verbose: int = 0, env=None):
@ -22,92 +22,10 @@ class CustomCallback(BaseCallback):
def _on_step(self) -> bool:
#print(self.env.printGrid())
print(self.env.printGrid())
return super()._on_step()
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, args=None, no_masking=False):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.args = args
self.no_masking = no_masking
def create_action_mask(self):
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = create_shield_dict(self.env, self.args)
self.keys = keys
self.shield = shield
return obs["image"], infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
return obs, rew, done, truncated, info
def mask_fn(env: gym.Env):
return env.create_action_mask()
@ -118,9 +36,13 @@ def main():
import argparse
args = parse_arguments(argparse)
args.grid_path = F"{args.grid_path}.txt"
args.prism_path = F"{args.prism_path}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(args.env, render_mode="rgb_array")
env = MiniGridEnvWrapper(env,args=args, no_masking=args.no_masking)
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
@ -132,7 +54,7 @@ def main():
model.learn(iterations, callback=callback)
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)
#W mean_reward, std_reward = evaluate_policy(model, model.get_env())
vec_env = model.get_env()
obs = vec_env.reset()

81
examples/shields/rl/ShieldHandlers.py

@ -0,0 +1,81 @@
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
from abc import ABC
import os
class ShieldHandler(ABC):
def __init__(self) -> None:
pass
def create_shield(self, **kwargs):
pass
class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None:
self.grid_file = grid_file
self.grid_to_prism_path = grid_to_prism_path
self.prism_path = prism_path
self.formula = formula
def __export_grid_to_text(self, env):
f = open(self.grid_file, "w")
f.write(env.printGrid(init=True))
f.close()
def __create_prism(self):
os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}")
f = open(self.prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
def __create_shield_dict(self):
program = stormpy.parse_prism_program(self.prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary
def create_shield(self, **kwargs):
env = kwargs["env"]
self.__export_grid_to_text(env)
self.__create_prism()
return self.__create_shield_dict()

3
examples/shields/rl/MaskModels.py → examples/shields/rl/TorchActionMaskModel.py

@ -11,7 +11,6 @@ torch, nn = try_import_torch()
class TorchActionMaskModel(TorchModelV2, nn.Module):
"""PyTorch version of above ActionMaskingModel."""
def __init__(
self,
@ -23,7 +22,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
**kwargs,
):
orig_space = getattr(obs_space, "original_space", obs_space)
custom_config = model_config['custom_model_config']
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs
@ -58,7 +56,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask
# Return masked logits.
return masked_logits, state

97
examples/shields/rl/Wrapper.py → examples/shields/rl/Wrappers.py

@ -7,10 +7,11 @@ from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping, create_shield_dict, extract_keys
from helpers import get_action_index_mapping, extract_keys
from ShieldHandlers import ShieldHandler
class OneHotWrapper(gym.core.ObservationWrapper):
class OneHotShieldingWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack):
super().__init__(env)
self.framestack = framestack
@ -80,9 +81,9 @@ class OneHotWrapper(gym.core.ObservationWrapper):
return tmp
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, args=None):
super(MiniGridEnvWrapper, self).__init__(env)
class MiniGridShieldingWrapper(gym.core.Wrapper):
def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True):
super(MiniGridShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = Dict(
{
@ -90,7 +91,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
}
)
self.args = args
self.shield_creator = shield_creator
self.create_shield_at_reset = create_shield_at_reset
self.shield = shield_creator.create_shield(env=self.env)
def create_action_mask(self):
coordinates = self.env.agent_pos
@ -142,7 +145,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
self.shield = create_shield_dict(self.env, self.args)
if self.create_shield_at_reset:
self.shield = self.shield_creator.create_shield(env=self.env)
self.keys = extract_keys(self.env)
mask = self.create_action_mask()
return {
@ -163,3 +169,80 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
return obs, rew, done, truncated, info
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self, env, shield_creator : ShieldHandler, no_masking=False):
super(MiniGridSbShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.shield_creator = shield_creator
self.no_masking = no_masking
def create_action_mask(self):
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
# if front_tile is not None and front_tile.type == "key":
# mask[Actions.pickup] = 1.0
# if self.env.carrying:
# mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = self.shield_creator.create_shield(env=self.env)
self.keys = keys
self.shield = shield
return obs["image"], infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
return obs, rew, done, truncated, info

77
examples/shields/rl/helpers.py

@ -1,6 +1,5 @@
import minigrid
from minigrid.core.actions import Actions
import gymnasium as gym
from datetime import datetime
@ -14,7 +13,6 @@ import stormpy.logic
import stormpy.examples
import stormpy.examples.files
import os
def extract_keys(env):
@ -66,17 +64,17 @@ def parse_arguments(argparse):
choices=[
"MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-FourRooms-v0",
"MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-DoorKey-16x16-v0",
"MiniGrid-Empty-Random-6x6-v0",
# "MiniGrid-DoorKey-8x8-v0",
# "MiniGrid-LockedRoom-v0",
# "MiniGrid-FourRooms-v0",
# "MiniGrid-LavaGapS7-v0",
# "MiniGrid-SimpleCrossingS9N3-v0",
# "MiniGrid-DoorKey-16x16-v0",
# "MiniGrid-Empty-Random-6x6-v0",
])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main")
parser.add_argument("--grid_to_prism_binary_path", default="./main")
parser.add_argument("--grid_path", default="grid")
parser.add_argument("--prism_path", default="grid")
parser.add_argument("--no_masking", default=False)
@ -90,62 +88,3 @@ def parse_arguments(argparse):
args = parser.parse_args()
return args
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
# print(env)
f.write(env.printGrid(init=True))
f.close()
def create_shield(grid_to_prism_path, grid_file, prism_path, formula):
os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")
f = open(prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
program = stormpy.parse_prism_program(prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
formulas = stormpy.parse_properties_for_prism_program(formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary
def create_shield_dict(env, args):
grid_file = args.grid_path
grid_to_prism_path = args.grid_to_prism_path
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path, args.formula)
return shield_dict
Loading…
Cancel
Save