diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py
index 4ce06a7..9f25be6 100644
--- a/examples/shields/rl/11_minigridrl.py
+++ b/examples/shields/rl/11_minigridrl.py
@@ -1,10 +1,10 @@
-# from typing import Dict
-# from ray.rllib.env.base_env import BaseEnv
-# from ray.rllib.evaluation import RolloutWorker
-# from ray.rllib.evaluation.episode import Episode
-# from ray.rllib.evaluation.episode_v2 import EpisodeV2
-# from ray.rllib.policy import Policy
-# from ray.rllib.utils.typing import PolicyID
+from typing import Dict
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.evaluation import RolloutWorker
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.episode_v2 import EpisodeV2
+from ray.rllib.policy import Policy
+from ray.rllib.utils.typing import PolicyID

 import gymnasium as gym

@@ -15,47 +15,47 @@ import minigrid
 from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.dqn.dqn import DQNConfig
-# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.tune.logger import pretty_print
 from ray.rllib.models import ModelCatalog

 from TorchActionMaskModel import TorchActionMaskModel
 from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
 from ShieldHandlers import MiniGridShieldHandler

 import matplotlib.pyplot as plt

 from ray.tune.logger import TBXLogger

-# class MyCallbacks(DefaultCallbacks):
-#     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         episode.user_data["count"] = 0
-#         # print("On episode start print")
-#         # print(env.printGrid())
-#         # print(worker)
-#         # print(env.action_space.n)
-#         # print(env.actions)
-#         # print(env.mission)
-#         # print(env.observation_space)
-#         # img = env.get_frame()
-#         # plt.imshow(img)
-#         # plt.show()
+class MyCallbacks(DefaultCallbacks):
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        episode.user_data["count"] = 0
+        # print("On episode start print")
+        # print(env.printGrid())
+        # print(worker)
+        # print(env.action_space.n)
+        # print(env.actions)
+        # print(env.mission)
+        # print(env.observation_space)
+        # img = env.get_frame()
+        # plt.imshow(img)
+        # plt.show()

-#     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         episode.user_data["count"] = episode.user_data["count"] + 1
-#         env = base_env.get_sub_environments()[0]
-#         # print(env.printGrid())
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        episode.user_data["count"] = episode.user_data["count"] + 1
+        env = base_env.get_sub_environments()[0]
+        # print(env.printGrid())

-#     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         #print("On episode end print")
-#         #print(env.printGrid())
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        #print("On episode end print")
+        #print(env.printGrid())

@@ -83,7 +83,7 @@ def shielding_env_creater(config):

 def register_minigrid_shielding_env(args):
-    env_name = "mini-grid"
+    env_name = "mini-grid-shielding"
     register_env(env_name, shielding_env_creater)

     ModelCatalog.register_custom_model(
@@ -98,25 +98,21 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"name": args.env, "args": args})
+        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        #.callbacks(MyCallbacks)
+        .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": TBXLogger,
             "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model",
-            "custom_model_config" : {"no_masking": args.no_masking}
+            "custom_model": "shielding_model"
         }))

-    algo =(
-
+    algo =(
         config.build()
-    )
-
-    algo.eva
+    )

     for i in range(args.iterations):
         result = algo.train()
@@ -134,7 +130,7 @@ def dqn(args):
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=args.workers)
-    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
+    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
     #config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
@@ -143,8 +139,7 @@
         "logdir": create_log_dir(args)
     })
     config = config.training(hiddens=[], dueling=False, model={
-        "custom_model": "shielding_model",
-        "custom_model_config" : {"no_masking": args.no_masking}
+        "custom_model": "shielding_model"
     })

     algo = (
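
For readers unfamiliar with RLlib's environment registry: the env="mini-grid-shielding" string used above resolves to whatever creator function was registered under that name, and the env_config dict (including the new "shielding" flag) is handed to that creator on every rollout worker. A minimal, self-contained sketch of that mechanism; the creator name below is illustrative and the example's shielding wrappers are omitted:

    from ray.tune import register_env
    import gymnasium as gym

    def example_env_creator(env_config):
        # RLlib calls this once per (sub-)environment on each rollout worker.
        # env_config is the dict given to .environment(env_config=...), extended
        # with fields such as worker_index, which the real creator uses to derive
        # per-worker grid/PRISM file names.
        name = env_config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
        shielding = env_config.get("shielding", False)
        env = gym.make(name)
        # The actual example wraps env in MiniGridShieldingWrapper(mask_actions=shielding)
        # and OneHotShieldingWrapper before returning it.
        return env

    register_env("mini-grid-shielding", example_env_creator)
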
diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py
new file mode 100644
index 0000000..047873e
--- /dev/null
+++ b/examples/shields/rl/14_train_eval.py
@@ -0,0 +1,114 @@
+
+import gymnasium as gym
+
+import minigrid
+# import numpy as np
+
+# import ray
+from ray.tune import register_env
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.dqn.dqn import DQNConfig
+# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.tune.logger import pretty_print
+from ray.rllib.models import ModelCatalog
+
+
+from TorchActionMaskModel import TorchActionMaskModel
+from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from ShieldHandlers import MiniGridShieldHandler
+
+import matplotlib.pyplot as plt
+
+from ray.tune.logger import TBXLogger
+
+
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
+
+    shielding = config.get("shielding", False)
+
+    # if shielding:
+    #     assert(False)
+
+    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
+
+    env = gym.make(name)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding)
+
+    env = OneHotShieldingWrapper(env,
+                                 config.vector_index if hasattr(config, "vector_index") else 0,
+                                 framestack=framestack
+                                 )
+
+
+    return env
+
+
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+
+    ModelCatalog.register_custom_model(
+        "shielding_model",
+        TorchActionMaskModel
+    )
+
+
+def ppo(args):
+    register_minigrid_shielding_env(args)
+
+    config = (PPOConfig()
+        .rollouts(num_rollout_workers=args.workers)
+        .resources(num_gpus=0)
+        .environment( env="mini-grid-shielding",
+                      env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
+        .framework("torch")
+        .evaluation(evaluation_config={ "evaluation_interval": 1,
+                                        "evaluation_parallel_to_training": False,
+                                        "env": "mini-grid-shielding",
+                                        "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Evaluation}})
+        #.callbacks(MyCallbacks)
+        .rl_module(_enable_rl_module_api = False)
+        .debugging(logger_config={
+            "type": TBXLogger,
+            "logdir": create_log_dir(args)
+        })
+        .training(_enable_learner_api=False ,model={
+            "custom_model": "shielding_model"
+        }))
+
+    algo =(
+
+        config.build()
+    )
+
+    iterations = args.iterations
+
+    for i in range(iterations):
+        algo.train()
+
+        if i % 5 == 0:
+            algo.save()
+
+
+    for i in range(iterations):
+        eval_result = algo.evaluate()
+        print(pretty_print(eval_result))
+
+
+def main():
+    import argparse
+    args = parse_arguments(argparse)
+
+    ppo(args)
+
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
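
The reason 14_train_eval.py builds two env_configs is that the --shielding mode decides per phase whether the action mask is applied: the training environment masks for Enabled/Training, the evaluation environment for Enabled/Evaluation. A small sketch of those two boolean expressions; masking_flags is a hypothetical helper, and importing helpers assumes the snippet runs from examples/shields/rl with the example's dependencies (including stormpy) installed:

    from helpers import ShieldingConfig

    def masking_flags(mode):
        # Mirrors the two env_config expressions used in ppo() above.
        train = mode is ShieldingConfig.Enabled or mode is ShieldingConfig.Training
        evaluate = mode is ShieldingConfig.Enabled or mode is ShieldingConfig.Evaluation
        return train, evaluate

    assert masking_flags(ShieldingConfig.Enabled) == (True, True)
    assert masking_flags(ShieldingConfig.Training) == (True, False)
    assert masking_flags(ShieldingConfig.Evaluation) == (False, True)
    assert masking_flags(ShieldingConfig.Disabled) == (False, False)
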
diff --git a/examples/shields/rl/TorchActionMaskModel.py b/examples/shields/rl/TorchActionMaskModel.py
index 42b6805..b478636 100644
--- a/examples/shields/rl/TorchActionMaskModel.py
+++ b/examples/shields/rl/TorchActionMaskModel.py
@@ -38,9 +38,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
             name + "_internal",
         )

-        self.no_masking = False
-        if "no_masking" in model_config["custom_model_config"]:
-            self.no_masking = model_config["custom_model_config"]["no_masking"]

     def forward(self, input_dict, state, seq_lens):
         # Extract the available actions tensor from the observation.
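
With the no_masking escape hatch removed, TorchActionMaskModel always applies the mask through the log trick kept in forward(): log(1) = 0 leaves allowed logits untouched, while log(0) = -inf is clamped to FLOAT_MIN, which drives blocked actions to (numerically) zero probability after softmax. A standalone illustration, using torch.finfo as a stand-in for RLlib's FLOAT_MIN constant:

    import torch

    FLOAT_MIN = torch.finfo(torch.float32).min  # stand-in for ray.rllib.utils.torch_utils.FLOAT_MIN

    logits = torch.tensor([[1.0, 2.0, 3.0]])
    action_mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is blocked by the shield

    inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)  # [0, FLOAT_MIN, 0]
    masked_logits = logits + inf_mask
    probs = torch.softmax(masked_logits, dim=-1)

    assert probs[0, 1] == 0.0  # the blocked action receives zero probability
    assert torch.isclose(probs.sum(), torch.tensor(1.0))
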
@@ -48,10 +45,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})

         action_mask = input_dict["obs"]["action_mask"]
-
-        # If action masking is disabled, directly return unmasked logits
-        if self.no_masking:
-            return logits, state

         inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
         masked_logits = logits + inf_mask
diff --git a/examples/shields/rl/Wrappers.py b/examples/shields/rl/Wrappers.py
index ef761fa..0e5cc05 100644
--- a/examples/shields/rl/Wrappers.py
+++ b/examples/shields/rl/Wrappers.py
@@ -82,7 +82,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):

 class MiniGridShieldingWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True):
+    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True, mask_actions=True):
         super(MiniGridShieldingWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -94,8 +94,12 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
         self.shield_creator = shield_creator
         self.create_shield_at_reset = create_shield_at_reset
         self.shield = shield_creator.create_shield(env=self.env)
+        self.mask_actions = mask_actions

     def create_action_mask(self):
+        if not self.mask_actions:
+            return np.array([1.0] * self.max_available_actions, dtype=np.int8)
+
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir
@@ -146,7 +150,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):

     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
-        if self.create_shield_at_reset:
+        if self.create_shield_at_reset and self.mask_actions:
             self.shield = self.shield_creator.create_shield(env=self.env)
         self.keys = extract_keys(self.env)
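
Two consequences of the new mask_actions flag are easy to miss: when it is False, reset() no longer recomputes the shield (so Storm is not invoked), and create_action_mask() degenerates to an all-ones vector, leaving every action available to the policy. A tiny sketch of that degenerate mask; 7 is MiniGrid's usual Discrete action count and the variable names are illustrative:

    import numpy as np

    max_available_actions = 7  # MiniGrid's default action space is Discrete(7)
    no_op_mask = np.array([1.0] * max_available_actions, dtype=np.int8)

    assert no_op_mask.tolist() == [1] * max_available_actions  # nothing is masked out
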
\"AgentIsInGoalAndNotDone\"]" parser.add_argument("--workers", type=int, default=1) + parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled) args = parser.parse_args()