Browse Source

renaming / shield handling changes

refactoring
Thomas Knoll 2 years ago
parent
commit
f3747a1479
  1. 122
      examples/shields/rl/11_minigridrl.py
  2. 4
      examples/shields/rl/12_basic_training.py
  3. 98
      examples/shields/rl/13_minigridsb.py
  4. 81
      examples/shields/rl/ShieldHandlers.py
  5. 3
      examples/shields/rl/TorchActionMaskModel.py
  6. 97
      examples/shields/rl/Wrappers.py
  7. 77
      examples/shields/rl/helpers.py

122
examples/shields/rl/11_minigridrl.py

@ -1,76 +1,78 @@
from typing import Dict
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from ray.rllib.algorithms.algorithm import Algorithm
# from typing import Dict
# from ray.rllib.env.base_env import BaseEnv
# from ray.rllib.evaluation import RolloutWorker
# from ray.rllib.evaluation.episode import Episode
# from ray.rllib.evaluation.episode_v2 import EpisodeV2
# from ray.rllib.policy import Policy
# from ray.rllib.utils.typing import PolicyID
import gymnasium as gym import gymnasium as gym
import minigrid import minigrid
import numpy as np
# import numpy as np
import ray
# import ray
from ray.tune import register_env from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
# from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog from ray.rllib.models import ModelCatalog
from ray.rllib.utils.torch_utils import FLOAT_MIN
from MaskModels import TorchActionMaskModel
from Wrapper import OneHotWrapper, MiniGridEnvWrapper
from TorchActionMaskModel import TorchActionMaskModel
from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir from helpers import parse_arguments, create_log_dir
from ShieldHandlers import MiniGridShieldHandler
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
# print("On episode start print")
# print(env.printGrid())
# print(worker)
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
#print("On episode end print")
#print(env.printGrid())
def env_creater_custom(config):
framestack = config.get("framestack", 4)
from ray.tune.logger import TBXLogger
# class MyCallbacks(DefaultCallbacks):
# def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
# env = base_env.get_sub_environments()[0]
# episode.user_data["count"] = 0
# # print("On episode start print")
# # print(env.printGrid())
# # print(worker)
# # print(env.action_space.n)
# # print(env.actions)
# # print(env.mission)
# # print(env.observation_space)
# # img = env.get_frame()
# # plt.imshow(img)
# # plt.show()
# def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# episode.user_data["count"] = episode.user_data["count"] + 1
# env = base_env.get_sub_environments()[0]
# # print(env.printGrid())
# def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
# env = base_env.get_sub_environments()[0]
# #print("On episode end print")
# #print(env.printGrid())
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4) framestack = config.get("framestack", 4)
args = config.get("args", None) args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name) env = gym.make(name)
env = MiniGridEnvWrapper(env, args=args)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator)
# env = minigrid.wrappers.ImgObsWrapper(env) # env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env) # env = ImgObsWrapper(env)
env = OneHotWrapper(env,
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0, config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack framestack=framestack
) )
@ -80,32 +82,32 @@ def env_creater_custom(config):
def register_custom_minigrid_env(args):
def register_minigrid_shielding_env(args):
env_name = "mini-grid" env_name = "mini-grid"
register_env(env_name, env_creater_custom)
register_env(env_name, shielding_env_creater)
ModelCatalog.register_custom_model( ModelCatalog.register_custom_model(
"pa_model",
"shielding_model",
TorchActionMaskModel TorchActionMaskModel
) )
def ppo(args): def ppo(args):
register_custom_minigrid_env(args)
register_minigrid_shielding_env(args)
config = (PPOConfig() config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers) .rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0) .resources(num_gpus=0)
.environment(env="mini-grid", env_config={"name": args.env, "args": args}) .environment(env="mini-grid", env_config={"name": args.env, "args": args})
.framework("torch") .framework("torch")
.callbacks(MyCallbacks)
#.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False) .rl_module(_enable_rl_module_api = False)
.debugging(logger_config={ .debugging(logger_config={
"type": "ray.tune.logger.TBXLogger",
"type": TBXLogger,
"logdir": create_log_dir(args) "logdir": create_log_dir(args)
}) })
.training(_enable_learner_api=False ,model={ .training(_enable_learner_api=False ,model={
"custom_model": "pa_model",
"custom_model": "shielding_model",
"custom_model_config" : {"no_masking": args.no_masking} "custom_model_config" : {"no_masking": args.no_masking}
})) }))
@ -114,6 +116,8 @@ def ppo(args):
config.build() config.build()
) )
algo.eva
for i in range(args.iterations): for i in range(args.iterations):
result = algo.train() result = algo.train()
print(pretty_print(result)) print(pretty_print(result))
@ -124,7 +128,7 @@ def ppo(args):
def dqn(args): def dqn(args):
register_custom_minigrid_env(args)
register_minigrid_shielding_env(args)
config = DQNConfig() config = DQNConfig()
@ -132,14 +136,14 @@ def dqn(args):
config = config.rollouts(num_rollout_workers=args.workers) config = config.rollouts(num_rollout_workers=args.workers)
config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args }) config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
config = config.framework("torch") config = config.framework("torch")
config = config.callbacks(MyCallbacks)
#config = config.callbacks(MyCallbacks)
config = config.rl_module(_enable_rl_module_api = False) config = config.rl_module(_enable_rl_module_api = False)
config = config.debugging(logger_config={ config = config.debugging(logger_config={
"type": "ray.tune.logger.TBXLogger",
"type": TBXLogger,
"logdir": create_log_dir(args) "logdir": create_log_dir(args)
}) })
config = config.training(hiddens=[], dueling=False, model={ config = config.training(hiddens=[], dueling=False, model={
"custom_model": "pa_model",
"custom_model": "shielding_model",
"custom_model_config" : {"no_masking": args.no_masking} "custom_model_config" : {"no_masking": args.no_masking}
}) })

4
examples/shields/rl/12_basic_training.py

@ -45,7 +45,7 @@ import argparse
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.framework import try_import_tf, try_import_torch
from Wrapper import OneHotWrapper
from examples.shields.rl.Wrappers import OneHotShieldingWrapper
torch, nn = try_import_torch() torch, nn = try_import_torch()
@ -162,7 +162,7 @@ def env_creater(config):
env = gym.make(name) env = gym.make(name)
# env = minigrid.wrappers.RGBImgPartialObsWrapper(env) # env = minigrid.wrappers.RGBImgPartialObsWrapper(env)
env = minigrid.wrappers.ImgObsWrapper(env) env = minigrid.wrappers.ImgObsWrapper(env)
env = OneHotWrapper(env,
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0, config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack framestack=framestack
) )

98
examples/shields/rl/13_minigridsb.py

@ -1,19 +1,19 @@
from sb3_contrib import MaskablePPO from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym import gymnasium as gym
from gymnasium.spaces import Dict, Box
from minigrid.core.actions import Actions from minigrid.core.actions import Actions
import numpy as np import numpy as np
import time import time
from helpers import create_shield_dict, parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
from helpers import parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
from ShieldHandlers import MiniGridShieldHandler
from Wrappers import MiniGridSbShieldingWrapper
class CustomCallback(BaseCallback): class CustomCallback(BaseCallback):
def __init__(self, verbose: int = 0, env=None): def __init__(self, verbose: int = 0, env=None):
@ -22,92 +22,10 @@ class CustomCallback(BaseCallback):
def _on_step(self) -> bool: def _on_step(self) -> bool:
#print(self.env.printGrid())
print(self.env.printGrid())
return super()._on_step() return super()._on_step()
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, args=None, no_masking=False):
super(MiniGridEnvWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.args = args
self.no_masking = no_masking
def create_action_mask(self):
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
if front_tile is not None and front_tile.type == "key":
mask[Actions.pickup] = 1.0
if self.env.carrying:
mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = create_shield_dict(self.env, self.args)
self.keys = keys
self.shield = shield
return obs["image"], infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
return obs, rew, done, truncated, info
def mask_fn(env: gym.Env): def mask_fn(env: gym.Env):
return env.create_action_mask() return env.create_action_mask()
@ -118,9 +36,13 @@ def main():
import argparse import argparse
args = parse_arguments(argparse) args = parse_arguments(argparse)
args.grid_path = F"{args.grid_path}.txt"
args.prism_path = F"{args.prism_path}.prism"
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(args.env, render_mode="rgb_array") env = gym.make(args.env, render_mode="rgb_array")
env = MiniGridEnvWrapper(env,args=args, no_masking=args.no_masking)
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, no_masking=args.no_masking)
env = ActionMasker(env, mask_fn) env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env) callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args)) model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
@ -132,7 +54,7 @@ def main():
model.learn(iterations, callback=callback) model.learn(iterations, callback=callback)
mean_reward, std_reward = evaluate_policy(model, model.get_env(), 10)
#W mean_reward, std_reward = evaluate_policy(model, model.get_env())
vec_env = model.get_env() vec_env = model.get_env()
obs = vec_env.reset() obs = vec_env.reset()

81
examples/shields/rl/ShieldHandlers.py

@ -0,0 +1,81 @@
import stormpy
import stormpy.core
import stormpy.simulator
import stormpy.shields
import stormpy.logic
import stormpy.examples
import stormpy.examples.files
from abc import ABC
import os
class ShieldHandler(ABC):
def __init__(self) -> None:
pass
def create_shield(self, **kwargs):
pass
class MiniGridShieldHandler(ShieldHandler):
def __init__(self, grid_file, grid_to_prism_path, prism_path, formula) -> None:
self.grid_file = grid_file
self.grid_to_prism_path = grid_to_prism_path
self.prism_path = prism_path
self.formula = formula
def __export_grid_to_text(self, env):
f = open(self.grid_file, "w")
f.write(env.printGrid(init=True))
f.close()
def __create_prism(self):
os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}")
f = open(self.prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
def __create_shield_dict(self):
program = stormpy.parse_prism_program(self.prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary
def create_shield(self, **kwargs):
env = kwargs["env"]
self.__export_grid_to_text(env)
self.__create_prism()
return self.__create_shield_dict()

3
examples/shields/rl/MaskModels.py → examples/shields/rl/TorchActionMaskModel.py

@ -11,7 +11,6 @@ torch, nn = try_import_torch()
class TorchActionMaskModel(TorchModelV2, nn.Module): class TorchActionMaskModel(TorchModelV2, nn.Module):
"""PyTorch version of above ActionMaskingModel."""
def __init__( def __init__(
self, self,
@ -23,7 +22,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
**kwargs, **kwargs,
): ):
orig_space = getattr(obs_space, "original_space", obs_space) orig_space = getattr(obs_space, "original_space", obs_space)
custom_config = model_config['custom_model_config']
TorchModelV2.__init__( TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name, **kwargs self, obs_space, action_space, num_outputs, model_config, name, **kwargs
@ -58,7 +56,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN) inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
masked_logits = logits + inf_mask masked_logits = logits + inf_mask
# Return masked logits. # Return masked logits.
return masked_logits, state return masked_logits, state

97
examples/shields/rl/Wrapper.py → examples/shields/rl/Wrappers.py

@ -7,10 +7,11 @@ from gymnasium.spaces import Dict, Box
from collections import deque from collections import deque
from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.numpy import one_hot
from helpers import get_action_index_mapping, create_shield_dict, extract_keys
from helpers import get_action_index_mapping, extract_keys
from ShieldHandlers import ShieldHandler
class OneHotWrapper(gym.core.ObservationWrapper):
class OneHotShieldingWrapper(gym.core.ObservationWrapper):
def __init__(self, env, vector_index, framestack): def __init__(self, env, vector_index, framestack):
super().__init__(env) super().__init__(env)
self.framestack = framestack self.framestack = framestack
@ -80,9 +81,9 @@ class OneHotWrapper(gym.core.ObservationWrapper):
return tmp return tmp
class MiniGridEnvWrapper(gym.core.Wrapper):
def __init__(self, env, args=None):
super(MiniGridEnvWrapper, self).__init__(env)
class MiniGridShieldingWrapper(gym.core.Wrapper):
def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True):
super(MiniGridShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n self.max_available_actions = env.action_space.n
self.observation_space = Dict( self.observation_space = Dict(
{ {
@ -90,7 +91,9 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
"action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8), "action_mask" : Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
} }
) )
self.args = args
self.shield_creator = shield_creator
self.create_shield_at_reset = create_shield_at_reset
self.shield = shield_creator.create_shield(env=self.env)
def create_action_mask(self): def create_action_mask(self):
coordinates = self.env.agent_pos coordinates = self.env.agent_pos
@ -142,7 +145,10 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
def reset(self, *, seed=None, options=None): def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options) obs, infos = self.env.reset(seed=seed, options=options)
self.shield = create_shield_dict(self.env, self.args)
if self.create_shield_at_reset:
self.shield = self.shield_creator.create_shield(env=self.env)
self.keys = extract_keys(self.env) self.keys = extract_keys(self.env)
mask = self.create_action_mask() mask = self.create_action_mask()
return { return {
@ -163,3 +169,80 @@ class MiniGridEnvWrapper(gym.core.Wrapper):
return obs, rew, done, truncated, info return obs, rew, done, truncated, info
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self, env, shield_creator : ShieldHandler, no_masking=False):
super(MiniGridSbShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.shield_creator = shield_creator
self.no_masking = no_masking
def create_action_mask(self):
if self.no_masking:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
allowed_actions = self.shield[cur_pos_str]
for allowed_action in allowed_actions:
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
# if front_tile is not None and front_tile.type == "key":
# mask[Actions.pickup] = 1.0
# if self.env.carrying:
# mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
keys = extract_keys(self.env)
shield = self.shield_creator.create_shield(env=self.env)
self.keys = keys
self.shield = shield
return obs["image"], infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
return obs, rew, done, truncated, info

77
examples/shields/rl/helpers.py

@ -1,6 +1,5 @@
import minigrid import minigrid
from minigrid.core.actions import Actions from minigrid.core.actions import Actions
import gymnasium as gym
from datetime import datetime from datetime import datetime
@ -14,7 +13,6 @@ import stormpy.logic
import stormpy.examples import stormpy.examples
import stormpy.examples.files import stormpy.examples.files
import os
def extract_keys(env): def extract_keys(env):
@ -66,17 +64,17 @@ def parse_arguments(argparse):
choices=[ choices=[
"MiniGrid-LavaCrossingS9N1-v0", "MiniGrid-LavaCrossingS9N1-v0",
"MiniGrid-LavaCrossingS9N3-v0", "MiniGrid-LavaCrossingS9N3-v0",
"MiniGrid-DoorKey-8x8-v0",
"MiniGrid-LockedRoom-v0",
"MiniGrid-FourRooms-v0",
"MiniGrid-LavaGapS7-v0",
"MiniGrid-SimpleCrossingS9N3-v0",
"MiniGrid-DoorKey-16x16-v0",
"MiniGrid-Empty-Random-6x6-v0",
# "MiniGrid-DoorKey-8x8-v0",
# "MiniGrid-LockedRoom-v0",
# "MiniGrid-FourRooms-v0",
# "MiniGrid-LavaGapS7-v0",
# "MiniGrid-SimpleCrossingS9N3-v0",
# "MiniGrid-DoorKey-16x16-v0",
# "MiniGrid-Empty-Random-6x6-v0",
]) ])
# parser.add_argument("--seed", type=int, help="seed for environment", default=None) # parser.add_argument("--seed", type=int, help="seed for environment", default=None)
parser.add_argument("--grid_to_prism_path", default="./main")
parser.add_argument("--grid_to_prism_binary_path", default="./main")
parser.add_argument("--grid_path", default="grid") parser.add_argument("--grid_path", default="grid")
parser.add_argument("--prism_path", default="grid") parser.add_argument("--prism_path", default="grid")
parser.add_argument("--no_masking", default=False) parser.add_argument("--no_masking", default=False)
@ -90,62 +88,3 @@ def parse_arguments(argparse):
args = parser.parse_args() args = parser.parse_args()
return args return args
def export_grid_to_text(env, grid_file):
f = open(grid_file, "w")
# print(env)
f.write(env.printGrid(init=True))
f.close()
def create_shield(grid_to_prism_path, grid_file, prism_path, formula):
os.system(F"{grid_to_prism_path} -v 'agent' -i {grid_file} -o {prism_path}")
f = open(prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
f.close()
program = stormpy.parse_prism_program(prism_path)
shield_specification = stormpy.logic.ShieldExpression(stormpy.logic.ShieldingType.PRE_SAFETY, stormpy.logic.ShieldComparison.RELATIVE, 0.1)
formulas = stormpy.parse_properties_for_prism_program(formula, program)
options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
options.set_build_state_valuations(True)
options.set_build_choice_labels(True)
options.set_build_all_labels()
model = stormpy.build_sparse_model_with_options(program, options)
result = stormpy.model_checking(model, formulas[0], extract_scheduler=True, shield_expression=shield_specification)
assert result.has_scheduler
assert result.has_shield
shield = result.shield
action_dictionary = {}
shield_scheduler = shield.construct()
for stateID in model.states:
choice = shield_scheduler.get_choice(stateID)
choices = choice.choice_map
state_valuation = model.state_valuations.get_string(stateID)
actions_to_be_executed = [(choice[1] ,model.choice_labeling.get_labels_of_choice(model.get_choice_index(stateID, choice[1]))) for choice in choices]
action_dictionary[state_valuation] = actions_to_be_executed
stormpy.shields.export_shield(model, shield, "Grid.shield")
return action_dictionary
def create_shield_dict(env, args):
grid_file = args.grid_path
grid_to_prism_path = args.grid_to_prism_path
export_grid_to_text(env, grid_file)
prism_path = args.prism_path
shield_dict = create_shield(grid_to_prism_path ,grid_file, prism_path, args.formula)
return shield_dict
Loading…
Cancel
Save