The requirements for applying a shield while training an RL agent in the MiniGrid environment with the PPO algorithm are:

# Binaries
- Tempest
- Minigrid2Prism


# Python packages:
- Tempestpy
- MiniGrid (a version that provides the `printGrid` function)
- Ray / RLlib

The shield handler is responsible for creating and querying the shield.

In [None]:

import os
from abc import ABC, abstractmethod

import stormpy
import stormpy.core
import stormpy.simulator

import stormpy.shields
import stormpy.logic

import stormpy.examples
import stormpy.examples.files

class Action:
    """One action permitted by the shield in a given state.

    Attributes:
        idx: Index of the choice within its state, as reported by the shield.
        prob: Probability weight the shield assigns to this action (default 1).
        labels: Choice labels of the action; a fresh empty list when omitted.
    """

    def __init__(self, idx, prob=1, labels=None) -> None:
        self.idx = idx
        self.prob = prob
        # A literal `[]` default would be one list shared by every Action
        # created without labels, so mutating one would change them all.
        self.labels = [] if labels is None else labels

class ShieldHandler(ABC):
    """Abstract interface for objects that build a shield for an environment.

    Subclasses must implement ``create_shield``. The shielding wrappers in this
    notebook look the returned dict up with the string produced by their
    ``shield_query_creator`` and expect lists of ``Action`` objects as values.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def create_shield(self, **kwargs) -> dict:
        """Build the shield and return it as a dict (state query -> allowed actions)."""

class MiniGridShieldHandler(ShieldHandler):
    """Build a pre-safety shield for a MiniGrid environment with Minigrid2Prism and stormpy.

    Pipeline: dump the environment's grid to ``grid_file`` (via the env's
    ``printGrid``), convert it to a PRISM program at ``prism_path`` with the
    ``grid_to_prism_path`` executable, then model-check ``formula`` with a
    relative pre-safety shield and tabulate the shield's choices per state.

    Args:
        grid_file: Path the textual grid is written to.
        grid_to_prism_path: Path to the Minigrid2Prism executable.
        prism_path: Path the generated PRISM program is written to.
        formula: Property string parsed against the PRISM program; only the
            first parsed property is model-checked.
        shield_value: Value of the relative pre-safety shield comparison
            (default 0.1).
    """

    def __init__(self, grid_file, grid_to_prism_path, prism_path, formula, shield_value=0.1) -> None:
        self.grid_file = grid_file
        self.grid_to_prism_path = grid_to_prism_path
        self.prism_path = prism_path
        self.formula = formula
        self.shield_value = shield_value

    def __export_grid_to_text(self, env):
        """Write ``env.printGrid(init=True)`` to ``self.grid_file``."""
        with open(self.grid_file, "w") as f:
            f.write(env.printGrid(init=True))

    def __create_prism(self):
        """Run Minigrid2Prism on the grid file and append the lava label to its output.

        Raises:
            AssertionError: if the converter cannot be started or exits non-zero.
        """
        import subprocess

        # Argument list, no shell: paths containing spaces or shell
        # metacharacters are passed through verbatim.
        command = [self.grid_to_prism_path, "-v", "agent", "-i", self.grid_file, "-o", self.prism_path]
        try:
            returncode = subprocess.run(command).returncode
        except OSError:
            # e.g. the converter binary is missing or not executable; report it
            # like any other failed conversion.
            returncode = None
        if returncode != 0:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError("Prism file could not be generated")

        with open(self.prism_path, "a") as f:
            # Leading newline: if the generated file does not end with one, the
            # label would otherwise be glued onto the last PRISM line.
            f.write("\nlabel \"AgentIsInLava\" = AgentIsInLava;")

    def __create_shield_dict(self):
        """Model-check the formula with a shield and map each state to its allowed actions.

        Returns:
            dict mapping each state's valuation string (from
            ``model.state_valuations.get_string``) to a list of ``Action``
            objects, one per entry of the shield's choice map for that state.

        Raises:
            AssertionError: if model checking yields no scheduler or no shield.
        """
        program = stormpy.parse_prism_program(self.prism_path)
        shield_specification = stormpy.logic.ShieldExpression(
            stormpy.logic.ShieldingType.PRE_SAFETY,
            stormpy.logic.ShieldComparison.RELATIVE,
            self.shield_value,
        )

        formulas = stormpy.parse_properties_for_prism_program(self.formula, program)
        options = stormpy.BuilderOptions([p.raw_formula for p in formulas])
        options.set_build_state_valuations(True)
        options.set_build_choice_labels(True)
        options.set_build_all_labels()
        model = stormpy.build_sparse_model_with_options(program, options)

        result = stormpy.model_checking(
            model, formulas[0], extract_scheduler=True, shield_expression=shield_specification
        )
        if not result.has_scheduler:
            raise AssertionError("Model checking did not produce a scheduler")
        if not result.has_shield:
            raise AssertionError("Model checking did not produce a shield")

        shield_scheduler = result.shield.construct()

        action_dictionary = {}
        for state_id in model.states:
            choice_map = shield_scheduler.get_choice(state_id).choice_map
            state_valuation = model.state_valuations.get_string(state_id)
            # Each choice_map entry is indexed as (probability, action index).
            action_dictionary[state_valuation] = [
                Action(
                    idx=entry[1],
                    prob=entry[0],
                    labels=model.choice_labeling.get_labels_of_choice(
                        model.get_choice_index(state_id, entry[1])
                    ),
                )
                for entry in choice_map
            ]

        return action_dictionary

    def create_shield(self, **kwargs):
        """Export the grid, generate the PRISM model and compute the shield.

        Keyword Args:
            env: Environment exposing ``printGrid(init=True)`` (required).

        Returns:
            The state-valuation -> list of ``Action`` dict described in
            ``__create_shield_dict``.
        """
        env = kwargs["env"]
        self.__export_grid_to_text(env)
        self.__create_prism()

        return self.__create_shield_dict()
        
def create_shield_query(env):
    """Build the shield lookup key for the agent's current position and heading.

    Reads ``agent_pos`` and ``agent_dir`` from ``env.env`` and formats them in
    the tab-separated state-valuation form used as keys of the shield dict.
    """
    inner = env.env
    position = inner.agent_pos
    heading = inner.agent_dir
    x, y = position[0], position[1]
    return f"[!AgentDone\t& xAgent={x}\t& yAgent={y}\t& viewAgent={heading}]"
    

To train a learning algorithm with shielding, the allowed actions need to be embedded in the observation.
This can be done by implementing a gym wrapper that handles the action embedding for the environment.

In [None]:
import gymnasium as gym
import numpy as np
import random

from minigrid.core.actions import Actions

from gymnasium.spaces import Dict, Box
from collections import deque
from ray.rllib.utils.numpy import one_hot

from helpers import get_action_index_mapping, extract_keys

class OneHotShieldingWrapper(gym.core.ObservationWrapper):
    """One-hot encode the image of a shielded observation and stack the last frames.

    Expects observations of the form ``{"data": image, "action_mask": mask}``
    (as produced by ``MiniGridShieldingWrapper``). Each image cell's three
    channels are one-hot encoded (16 object types, 6 colors, 3 states), the
    agent direction is appended as a 4-way one-hot, and the last ``framestack``
    frames are concatenated into ``"data"``. The action mask is passed through.

    For ``vector_index == 0`` it also prints, at each episode start, the
    100-episode mean of the maximum distance the agent got from its start cell.

    Args:
        env: Environment to wrap.
        vector_index: Index of this env within a vectorized set.
        framestack: Number of frames to stack.
    """

    def __init__(self, env, vector_index, framestack):
        super().__init__(env)
        self.framestack = framestack
        # Per cell: 16 object types + 6 colors + 3 states; +4 for the direction.
        # The view size is taken from the wrapped "data" space (7x7 for
        # MiniGrid's default agent view) instead of being hard-coded.
        view_height, view_width = env.observation_space["data"].shape[:2]
        self.single_frame_dim = view_height * view_width * (16 + 6 + 3) + 4
        self.init_x = None
        self.init_y = None
        self.x_positions = []
        self.y_positions = []
        self.x_y_delta_buffer = deque(maxlen=100)
        self.vector_index = vector_index
        self.frame_buffer = deque(maxlen=self.framestack)
        self._clear_frames()

        self.observation_space = Dict(
            {
                "data": gym.spaces.Box(0.0, 1.0, shape=(self.single_frame_dim * self.framestack,), dtype=np.float32),
                "action_mask": gym.spaces.Box(0, 10, shape=(env.action_space.n,), dtype=int),
            }
        )

    def _clear_frames(self):
        """Fill the frame buffer with all-zero frames of the declared float32 dtype."""
        for _ in range(self.framestack):
            # float32 to match the "data" space; float64 zeros would upcast the
            # whole stacked observation.
            self.frame_buffer.append(np.zeros((self.single_frame_dim,), dtype=np.float32))

    def _log_exploration(self, base_env):
        """Print the running mean of per-episode max distance, then restart tracking."""
        if self.x_positions:
            distances = np.sqrt(
                (np.array(self.x_positions) - self.init_x) ** 2
                + (np.array(self.y_positions) - self.init_y) ** 2
            )
            self.x_y_delta_buffer.append(max(distances))
            print("100-average dist travelled={}".format(np.mean(self.x_y_delta_buffer)))
            self.x_positions = []
            self.y_positions = []
        self.init_x = base_env.agent_pos[0]
        self.init_y = base_env.agent_pos[1]

    def observation(self, obs):
        """Return ``{"data": stacked one-hot frames, "action_mask": obs["action_mask"]}``."""
        # Read MiniGrid state from the base env directly instead of relying on
        # wrapper attribute forwarding, which gymnasium deprecates.
        base_env = self.unwrapped

        if base_env.step_count == 0:
            # New episode: drop frames from the previous one.
            self._clear_frames()
            if self.vector_index == 0:
                self._log_exploration(base_env)

        self.x_positions.append(base_env.agent_pos[0])
        self.y_positions.append(base_env.agent_pos[1])

        image = obs["data"]

        # One-hot the last dim into 16, 6, 3 one-hot vectors, then flatten.
        objects = one_hot(image[:, :, 0], depth=16)
        colors = one_hot(image[:, :, 1], depth=6)
        states = one_hot(image[:, :, 2], depth=3)

        all_flat = np.reshape(np.concatenate([objects, colors, states], -1), (-1,))
        direction = one_hot(np.array(base_env.agent_dir), depth=4).astype(np.float32)
        single_frame = np.concatenate([all_flat, direction])
        self.frame_buffer.append(single_frame)

        return {
            "data": np.concatenate(self.frame_buffer).astype(np.float32, copy=False),
            "action_mask": obs["action_mask"],
        }

# Environment wrapper handling action embedding in observations
class MiniGridShieldingWrapper(gym.core.Wrapper):
    """Embed the shield's action mask into every observation (RLlib variant).

    Observations become ``{"data": <MiniGrid image>, "action_mask": <int8 mask>}``
    where 1 marks an action the agent may take.

    Args:
        env: MiniGrid environment whose observations contain ``"image"``.
        shield_creator: Builds the shield dict via ``create_shield(env=...)``.
        shield_query_creator: Callable mapping the env to the key used to look
            up the current state in the shield dict.
        create_shield_at_reset: Rebuild the shield on every ``reset`` (only
            when ``mask_actions`` is true).
        mask_actions: When false, every action is always allowed.
    """

    def __init__(self,
                 env,
                 shield_creator: ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = Dict(
            {
                "data": env.observation_space.spaces["image"],
                "action_mask": Box(0, 10, shape=(self.max_available_actions,), dtype=np.int8),
            }
        )
        self.shield_creator = shield_creator
        self.create_shield_at_reset = create_shield_at_reset
        # Built here too, so a shield exists even when create_shield_at_reset
        # is False.
        self.shield = shield_creator.create_shield(env=self.env)
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator

    def create_action_mask(self):
        """Return the int8 action mask for the current state.

        If the shield has a non-empty entry for the current state, each listed
        action is enabled with its shield probability (sampled via
        ``random.choices``) and every other action is disabled. States that are
        missing or empty in the shield allow all actions. Picking up a key or
        toggling a door directly in front of the agent is always allowed.

        Raises:
            ValueError: if a shield action's labels map to no action index.
        """
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        cur_pos_str = self.shield_query_creator(self.env)
        allowed_actions = self.shield.get(cur_pos_str)

        if allowed_actions:
            mask = np.zeros(self.max_available_actions, dtype=np.int8)
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)  # labels presumably a set
                if index is None:
                    # Explicit raise: an `assert` vanishes under -O, and then
                    # `mask[None] = ...` would silently overwrite the whole mask.
                    raise ValueError(f"No action index for shield labels {allowed_action.labels!r}")
                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
        else:
            mask = np.ones(self.max_available_actions, dtype=np.int8)

        base_env = self.env.unwrapped
        front_tile = base_env.grid.get(base_env.front_pos[0], base_env.front_pos[1])
        if front_tile is not None and front_tile.type == "key":
            mask[Actions.pickup] = 1
        if front_tile is not None and front_tile.type == "door":
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the env, optionally rebuild the shield, and return the masked observation."""
        obs, infos = self.env.reset(seed=seed, options=options)

        if self.create_shield_at_reset and self.mask_actions:
            self.shield = self.shield_creator.create_shield(env=self.env)

        self.keys = extract_keys(self.env)
        mask = self.create_action_mask()
        return {
            "data": obs["image"],
            "action_mask": mask,
        }, infos

    def step(self, action):
        """Step the env and return the image plus the action mask for the new state."""
        orig_obs, rew, done, truncated, info = self.env.step(action)

        obs = {
            "data": orig_obs["image"],
            "action_mask": self.create_action_mask(),
        }
        return obs, rew, done, truncated, info


# Wrapper to use with a stable baseline algorithm
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
    """Shielding wrapper for Stable-Baselines: plain image observations, mask on demand.

    Observations are just the MiniGrid image; the action mask is produced by
    ``create_action_mask`` (presumably consumed by an external action-masking
    wrapper — confirm with the caller).

    Args:
        env: MiniGrid environment whose observations contain ``"image"``.
        shield_creator: Builds the shield dict via ``create_shield(env=...)``.
        shield_query_creator: Callable mapping the env to the shield lookup key.
        create_shield_at_reset: Rebuild the shield on every ``reset``; when
            false it is built only on the first reset.
        mask_actions: When false, every action is always allowed.
    """

    def __init__(self,
                 env,
                 shield_creator: ShieldHandler,
                 shield_query_creator,
                 create_shield_at_reset=True,
                 mask_actions=True,
                 ):
        super().__init__(env)
        self.max_available_actions = env.action_space.n
        self.observation_space = env.observation_space.spaces["image"]

        self.shield_creator = shield_creator
        # Previously accepted but ignored; now honoured in reset().
        self.create_shield_at_reset = create_shield_at_reset
        self.mask_actions = mask_actions
        self.shield_query_creator = shield_query_creator
        # Built on the first reset (and on every reset if create_shield_at_reset).
        self.shield = None

    def create_action_mask(self):
        """Return the int8 action mask for the current state.

        Shield-listed actions are enabled with their shield probability
        (sampled via ``random.choices``) and all others disabled; states that
        are missing or empty in the shield allow every action. Toggling a door
        directly in front of the agent is always allowed.

        Raises:
            ValueError: if a shield action's labels map to no action index.
        """
        if not self.mask_actions:
            return np.ones(self.max_available_actions, dtype=np.int8)

        cur_pos_str = self.shield_query_creator(self.env)
        allowed_actions = self.shield.get(cur_pos_str)

        if allowed_actions:
            mask = np.zeros(self.max_available_actions, dtype=np.int8)
            for allowed_action in allowed_actions:
                index = get_action_index_mapping(allowed_action.labels)
                if index is None:
                    # Explicit raise: an `assert` vanishes under -O, and then
                    # `mask[None] = ...` would silently overwrite the whole mask.
                    raise ValueError(f"No action index for shield labels {allowed_action.labels!r}")
                mask[index] = random.choices([0.0, 1.0], weights=(1 - allowed_action.prob, allowed_action.prob))[0]
        else:
            mask = np.ones(self.max_available_actions, dtype=np.int8)

        base_env = self.env.unwrapped
        front_tile = base_env.grid.get(base_env.front_pos[0], base_env.front_pos[1])
        if front_tile is not None and front_tile.type == "door":
            mask[Actions.toggle] = 1

        return mask

    def reset(self, *, seed=None, options=None):
        """Reset the env, (re)build the shield as configured, and return the image."""
        obs, infos = self.env.reset(seed=seed, options=options)

        self.keys = extract_keys(self.env)
        if self.shield is None or self.create_shield_at_reset:
            self.shield = self.shield_creator.create_shield(env=self.env)
        return obs["image"], infos

    def step(self, action):
        """Step the env and return its result with the observation reduced to the image."""
        orig_obs, rew, done, truncated, info = self.env.step(action)
        return orig_obs["image"], rew, done, truncated, info



If we want to use RLlib algorithms, we additionally need a model that performs the action masking.

In [None]:
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX

torch, nn = try_import_torch()

class TorchActionMaskModel(TorchModelV2, nn.Module):
    """RLlib torch model that masks out the actions the shield disallows.

    Observations are dicts with ``"data"`` (fed to an internal fully connected
    network) and ``"action_mask"`` (1 = allowed, 0 = masked). Masked actions
    get their logits pushed down to about ``FLOAT_MIN``.
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        # RLlib may attach the unflattened Dict space as `original_space`;
        # fall back to obs_space itself when it is absent.
        dict_space = getattr(obs_space, "original_space", obs_space)

        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs)
        nn.Module.__init__(self)

        self.count = 0

        internal_name = name + "_internal"
        self.internal_model = TorchFC(dict_space["data"], action_space, num_outputs, model_config, internal_name)

    def forward(self, input_dict, state, seq_lens):
        observation = input_dict["obs"]
        raw_logits, _ = self.internal_model({"obs": observation["data"]})

        # log(1) = 0 leaves allowed logits untouched; log(0) = -inf is clamped
        # to FLOAT_MIN so masked actions become (practically) unselectable.
        penalty = torch.log(observation["action_mask"]).clamp(min=FLOAT_MIN)
        return raw_logits + penalty, state

    def value_function(self):
        return self.internal_model.value_function()

Using these components, we can now train an RL agent with shielding.

In [None]:
import gymnasium as gym
import minigrid

from ray import tune, air
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog


def shielding_env_creater(config):
    """Create the shielded, one-hot-encoded MiniGrid environment used for training.

    Optional ``config`` keys (defaults in parentheses):
        name: gymnasium env id ("MiniGrid-LavaCrossingS9N1-v0").
        framestack: number of stacked frames (4).
        grid_file: file the textual grid is written to ("grid.txt").
        grid_to_prism_path: Minigrid2Prism executable ("./main").
        prism_path: file the generated PRISM program is written to ("grid.prism").
        formula: property the shield is computed for.

    If ``config`` has a ``vector_index`` attribute (as RLlib's EnvContext
    does) it is forwarded to the one-hot wrapper; otherwise 0 is used.
    """
    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    framestack = config.get("framestack", 4)

    # NOTE(review): envs created in parallel with the default paths would all
    # write the same grid/PRISM files; pass distinct paths via config if so.
    shield_creator = MiniGridShieldHandler(
        config.get("grid_file", "grid.txt"),
        config.get("grid_to_prism_path", "./main"),
        config.get("prism_path", "grid.prism"),
        config.get("formula", "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]"),
    )

    env = gym.make(name)
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=True)
    env = OneHotShieldingWrapper(env, getattr(config, "vector_index", 0), framestack=framestack)

    return env


def register_minigrid_shielding_env():
    """Register the shielded env creator and the action-masking model with RLlib.

    After this call the env is available as "mini-grid-shielding" and the
    model as the custom model "shielding_model".
    """
    register_env("mini-grid-shielding", shielding_env_creater)
    ModelCatalog.register_custom_model("shielding_model", TorchActionMaskModel)

import pprint

register_minigrid_shielding_env()

# PPO on the shielded env. The RL-module and learner APIs are disabled,
# presumably because the custom model is an old-stack TorchModelV2.
config = (
    PPOConfig()
    .rollouts(num_rollout_workers=1)
    .resources(num_gpus=0)
    .environment(env="mini-grid-shielding", env_config={"name": "MiniGrid-LavaCrossingS9N1-v0"})
    .framework("torch")
    .rl_module(_enable_rl_module_api=False)
    .training(_enable_learner_api=False, model={"custom_model": "shielding_model"})
)

# Stop thresholds passed to Tune's RunConfig.
stop_criteria = {
    "episode_reward_mean": 94,
    "timesteps_total": 12000,
    "training_iteration": 12,
}

tuner = tune.Tuner(
    "PPO",
    tune_config=tune.TuneConfig(metric="episode_reward_mean", mode="max", num_samples=1),
    run_config=air.RunConfig(
        stop=stop_criteria,
        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
    ),
    param_space=config,
)

results = tuner.fit()
best_result = results.get_best_result()

metrics_to_print = [
    "episode_reward_mean",
    "episode_reward_max",
    "episode_reward_min",
    "episode_len_mean",
]
selected_metrics = {metric: value for metric, value in best_result.metrics.items() if metric in metrics_to_print}
pprint.pprint(selected_metrics)

      