## Example of how to combine shielding with RLlib's DQN algorithm.

In [None]:
import gymnasium as gym

import minigrid

from ray import tune, air
from ray.tune import register_env
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog


from torch_action_mask_model import TorchActionMaskModel
from wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from shieldhandlers import MiniGridShieldHandler, create_shield_query
 

In [None]:
def shielding_env_creater(config):
    """Build the shielded, one-hot encoded MiniGrid environment.

    Args:
        config: Mapping-like environment config. The optional keys
            ``"name"`` (gymnasium environment id, default
            ``"MiniGrid-LavaCrossingS9N1-v0"``) and ``"framestack"``
            (default ``4``) are read. If the object exposes a
            ``vector_index`` attribute it is forwarded to the one-hot
            wrapper; otherwise ``0`` is used.

    Returns:
        The gymnasium environment wrapped first in
        ``MiniGridShieldingWrapper`` and then in ``OneHotShieldingWrapper``.
    """
    env_id = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
    stack_size = config.get("framestack", 4)

    # NOTE(review): the handler arguments are presumably grid file, model
    # builder binary, PRISM file and the safety property -- confirm against
    # MiniGridShieldHandler's signature.
    handler_args = (
        "grid.txt",
        "./main",
        "grid.prism",
        "Pmax=? [G !\"AgentIsInLavaAndNotDone\"]",
    )
    shield_handler = MiniGridShieldHandler(*handler_args)

    base_env = gym.make(env_id)
    shielded_env = MiniGridShieldingWrapper(
        base_env,
        shield_creator=shield_handler,
        shield_query_creator=create_shield_query,
        mask_actions=True,
    )

    # Plain dict configs have no ``vector_index``; fall back to 0 for them.
    vector_index = getattr(config, "vector_index", 0)
    return OneHotShieldingWrapper(
        shielded_env,
        vector_index,
        framestack=stack_size,
    )


def register_minigrid_shielding_env():
    """Register the shielded MiniGrid environment and its custom model.

    The environment creator is registered with Ray Tune under the name
    ``"mini-grid-shielding"`` and ``TorchActionMaskModel`` is registered
    with RLlib's ``ModelCatalog`` under the name ``"shielding_model"``.
    """
    register_env("mini-grid-shielding", shielding_env_creater)
    ModelCatalog.register_custom_model("shielding_model", TorchActionMaskModel)

In [None]:
# Make the custom environment and model known to Ray before they are
# referenced by name in the configuration below.
register_minigrid_shielding_env()

# DQN configuration written as one fluent chain; each call's result feeds
# the next, exactly like repeated ``config = config.<method>(...)``.
config = (
    DQNConfig()
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=1)
    .environment(
        env="mini-grid-shielding",
        env_config={"name": "MiniGrid-LavaCrossingS9N1-v0"},
    )
    .framework("torch")
    .rl_module(_enable_rl_module_api=False)
    .training(
        hiddens=[],
        dueling=False,
        model={"custom_model": "shielding_model"},
    )
)

# Tune settings: a single sample, ranked by maximum mean episode reward.
tune_settings = tune.TuneConfig(
    metric="episode_reward_mean",
    mode="max",
    num_samples=1,
)

# Stopping criteria and checkpoint policy handed to the run configuration.
stop_criteria = {
    "episode_reward_mean": 94,
    "timesteps_total": 12000,
    "training_iteration": 12,
}
run_settings = air.RunConfig(
    stop=stop_criteria,
    checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, num_to_keep=2),
)

tuner = tune.Tuner(
    "DQN",
    tune_config=tune_settings,
    run_config=run_settings,
    param_space=config,
)

tuner.fit()
