
Added more shielding options (training, evaluation, none, both); refactoring.
Thomas Knoll, 2 years ago, commit e2c855dc6a
  1. examples/shields/rl/11_minigridrl.py (87 changed lines)
  2. examples/shields/rl/14_train_eval.py (114 changed lines)
  3. examples/shields/rl/TorchActionMaskModel.py (7 changed lines)
  4. examples/shields/rl/Wrappers.py (8 changed lines)
  5. examples/shields/rl/helpers.py (17 changed lines)

examples/shields/rl/11_minigridrl.py (87 changed lines)

@@ -1,10 +1,10 @@
-# from typing import Dict
-# from ray.rllib.env.base_env import BaseEnv
-# from ray.rllib.evaluation import RolloutWorker
-# from ray.rllib.evaluation.episode import Episode
-# from ray.rllib.evaluation.episode_v2 import EpisodeV2
-# from ray.rllib.policy import Policy
-# from ray.rllib.utils.typing import PolicyID
+from typing import Dict
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.evaluation import RolloutWorker
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.episode_v2 import EpisodeV2
+from ray.rllib.policy import Policy
+from ray.rllib.utils.typing import PolicyID
 import gymnasium as gym
@@ -15,47 +15,47 @@ import minigrid
 from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.dqn.dqn import DQNConfig
-# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.tune.logger import pretty_print
 from ray.rllib.models import ModelCatalog
 from TorchActionMaskModel import TorchActionMaskModel
 from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
 from ShieldHandlers import MiniGridShieldHandler
 import matplotlib.pyplot as plt
 from ray.tune.logger import TBXLogger
 
-# class MyCallbacks(DefaultCallbacks):
-#     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         episode.user_data["count"] = 0
-#         # print("On episode start print")
-#         # print(env.printGrid())
-#         # print(worker)
-#         # print(env.action_space.n)
-#         # print(env.actions)
-#         # print(env.mission)
-#         # print(env.observation_space)
-#         # img = env.get_frame()
-#         # plt.imshow(img)
-#         # plt.show()
-#     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         episode.user_data["count"] = episode.user_data["count"] + 1
-#         env = base_env.get_sub_environments()[0]
-#         # print(env.printGrid())
-#     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         #print("On episode end print")
-#         #print(env.printGrid())
+class MyCallbacks(DefaultCallbacks):
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        episode.user_data["count"] = 0
+        # print("On episode start print")
+        # print(env.printGrid())
+        # print(worker)
+        # print(env.action_space.n)
+        # print(env.actions)
+        # print(env.mission)
+        # print(env.observation_space)
+        # img = env.get_frame()
+        # plt.imshow(img)
+        # plt.show()
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        episode.user_data["count"] = episode.user_data["count"] + 1
+        env = base_env.get_sub_environments()[0]
+        # print(env.printGrid())
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        #print("On episode end print")
+        #print(env.printGrid())
@@ -83,7 +83,7 @@ def shielding_env_creater(config):
 def register_minigrid_shielding_env(args):
-    env_name = "mini-grid"
+    env_name = "mini-grid-shielding"
     register_env(env_name, shielding_env_creater)
     ModelCatalog.register_custom_model(
@@ -98,25 +98,21 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"name": args.env, "args": args})
+        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        #.callbacks(MyCallbacks)
+        .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
             "type": TBXLogger,
             "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model",
-            "custom_model_config" : {"no_masking": args.no_masking}
+            "custom_model": "shielding_model"
         }))
     algo =(
         config.build()
     )
-    algo.eva
 
     for i in range(args.iterations):
         result = algo.train()
@@ -134,7 +130,7 @@ def dqn(args):
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=args.workers)
-    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
+    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
     #config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
@@ -143,8 +139,7 @@ def dqn(args):
             "logdir": create_log_dir(args)
         })
     config = config.training(hiddens=[], dueling=False, model={
-        "custom_model": "shielding_model",
-        "custom_model_config" : {"no_masking": args.no_masking}
+        "custom_model": "shielding_model"
     })
     algo = (
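
Note: the `no_masking` entry is gone from `custom_model_config` in both the PPO and DQN configs; whether actions are masked is now decided per environment through the `"shielding"` key in `env_config`. A minimal sketch of how that boolean falls out of the new enum (the expression is copied from the PPO config above; the `Args` class is only a hypothetical stand-in for the parsed CLI namespace):

    from helpers import ShieldingConfig

    class Args:  # hypothetical stand-in for the argparse namespace
        shielding = ShieldingConfig.Training
        env = "MiniGrid-LavaCrossingS9N1-v0"

    args = Args()
    env_config = {
        "name": args.env,
        "args": args,
        # shield during training for the 'full' and 'training' modes
        "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training,
    }
    assert env_config["shielding"] is True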

examples/shields/rl/14_train_eval.py (new file, 114 lines)

@@ -0,0 +1,114 @@
+import gymnasium as gym
+import minigrid
+# import numpy as np
+# import ray
+from ray.tune import register_env
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.dqn.dqn import DQNConfig
+# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.tune.logger import pretty_print
+from ray.rllib.models import ModelCatalog
+from TorchActionMaskModel import TorchActionMaskModel
+from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from ShieldHandlers import MiniGridShieldHandler
+import matplotlib.pyplot as plt
+from ray.tune.logger import TBXLogger
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
+    shielding = config.get("shielding", False)
+    # if shielding:
+    #     assert(False)
+    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
+    env = gym.make(name)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding)
+    env = OneHotShieldingWrapper(env,
+                                 config.vector_index if hasattr(config, "vector_index") else 0,
+                                 framestack=framestack
+                                 )
+    return env
+
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+
+    ModelCatalog.register_custom_model(
+        "shielding_model",
+        TorchActionMaskModel
+    )
+
+def ppo(args):
+    register_minigrid_shielding_env(args)
+
+    config = (PPOConfig()
+        .rollouts(num_rollout_workers=args.workers)
+        .resources(num_gpus=0)
+        .environment( env="mini-grid-shielding",
+                      env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
+        .framework("torch")
+        .evaluation(evaluation_config={ "evaluation_interval": 1,
+                                        "evaluation_parallel_to_training": False,
+                                        "env": "mini-grid-shielding",
+                                        "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Evaluation}})
+        #.callbacks(MyCallbacks)
+        .rl_module(_enable_rl_module_api = False)
+        .debugging(logger_config={
+            "type": TBXLogger,
+            "logdir": create_log_dir(args)
+        })
+        .training(_enable_learner_api=False ,model={
+            "custom_model": "shielding_model"
+        }))
+
+    algo =(
+        config.build()
+    )
+
+    iterations = args.iterations
+
+    for i in range(iterations):
+        algo.train()
+
+        if i % 5 == 0:
+            algo.save()
+
+    for i in range(iterations):
+        eval_result = algo.evaluate()
+        print(pretty_print(eval_result))
+
+def main():
+    import argparse
+    args = parse_arguments(argparse)
+
+    ppo(args)
+
+if __name__ == '__main__':
+    main()
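
The new script separates shielding for the two phases: the training `env_config` shields for `Enabled`/`Training`, while the `evaluation_config` override shields for `Enabled`/`Evaluation` (in RLlib, keys inside `evaluation_config` override the corresponding training settings on the evaluation workers). A small truth-table sketch of those two expressions; the helper names are mine, the boolean logic is the one used in the config above:

    from helpers import ShieldingConfig

    def shielded_in_training(mode):
        return mode is ShieldingConfig.Enabled or mode is ShieldingConfig.Training

    def shielded_in_evaluation(mode):
        return mode is ShieldingConfig.Enabled or mode is ShieldingConfig.Evaluation

    for mode in ShieldingConfig:
        print(f"{str(mode)}: train={shielded_in_training(mode)}, eval={shielded_in_evaluation(mode)}")
    # training: train=True, eval=False
    # evaluation: train=False, eval=True
    # none: train=False, eval=False
    # full: train=True, eval=True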

examples/shields/rl/TorchActionMaskModel.py (7 changed lines)

@@ -38,9 +38,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
             name + "_internal",
         )
-        self.no_masking = False
-        if "no_masking" in model_config["custom_model_config"]:
-            self.no_masking = model_config["custom_model_config"]["no_masking"]
 
     def forward(self, input_dict, state, seq_lens):
         # Extract the available actions tensor from the observation.
@@ -48,10 +45,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})
         action_mask = input_dict["obs"]["action_mask"]
 
-        # If action masking is disabled, directly return unmasked logits
-        if self.no_masking:
-            return logits, state
-
         inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
         masked_logits = logits + inf_mask
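
With the `no_masking` branch gone, the model always adds the mask term; an all-ones mask (what the wrapper now returns when shielding is off, see Wrappers.py below) leaves the logits untouched, so the old escape hatch is no longer needed. A standalone check of that arithmetic with made-up logits (FLOAT_MIN is written out here rather than imported from RLlib):

    import torch

    FLOAT_MIN = torch.finfo(torch.float32).min  # stand-in for RLlib's FLOAT_MIN constant

    logits = torch.tensor([[0.3, -1.2, 0.8]])
    allow_all = torch.tensor([[1.0, 1.0, 1.0]])   # unshielded: every action allowed
    block_last = torch.tensor([[1.0, 1.0, 0.0]])  # shielded: third action forbidden

    inf_mask = torch.clamp(torch.log(allow_all), min=FLOAT_MIN)
    assert torch.equal(logits + inf_mask, logits)  # all-ones mask is a no-op

    inf_mask = torch.clamp(torch.log(block_last), min=FLOAT_MIN)
    masked_logits = logits + inf_mask  # forbidden action ends up at ~FLOAT_MIN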

examples/shields/rl/Wrappers.py (8 changed lines)

@@ -82,7 +82,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
 class MiniGridShieldingWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True):
+    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True, mask_actions=True):
         super(MiniGridShieldingWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -94,8 +94,12 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
         self.shield_creator = shield_creator
         self.create_shield_at_reset = create_shield_at_reset
         self.shield = shield_creator.create_shield(env=self.env)
+        self.mask_actions = mask_actions
 
     def create_action_mask(self):
+        if not self.mask_actions:
+            return np.array([1.0] * self.max_available_actions, dtype=np.int8)
+
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir
@@ -146,7 +150,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
-        if self.create_shield_at_reset:
+        if self.create_shield_at_reset and self.mask_actions:
             self.shield = self.shield_creator.create_shield(env=self.env)
 
         self.keys = extract_keys(self.env)
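
Masking is therefore toggled per environment instance rather than inside the model: with `mask_actions=False` the wrapper keeps the same observation layout, but `create_action_mask()` reports every action as allowed and `reset()` skips recomputing the shield (one shield is still built in the constructor). A rough usage sketch; the file paths and environment id are hypothetical, the argument order follows the calls in this diff:

    import gymnasium as gym
    from ShieldHandlers import MiniGridShieldHandler
    from Wrappers import MiniGridShieldingWrapper, OneHotShieldingWrapper

    # hypothetical grid/PRISM paths; in the scripts these come from the parsed CLI args
    shield_creator = MiniGridShieldHandler("grid_1.txt", "./main", "grid_1.prism",
                                           'Pmax=? [G !"AgentIsInLavaAndNotDone"]')

    env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
    # mask_actions=False -> all-ones action mask, no shield rebuild on reset()
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=False)
    env = OneHotShieldingWrapper(env, 0, framestack=4)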

examples/shields/rl/helpers.py (17 changed lines)

@@ -2,6 +2,9 @@ import minigrid
 from minigrid.core.actions import Actions
 from datetime import datetime
+from enum import Enum
+import os
 
 import stormpy
 import stormpy.core
@@ -13,8 +16,16 @@ import stormpy.logic
 import stormpy.examples
 import stormpy.examples.files
 
+class ShieldingConfig(Enum):
+    Training = 'training'
+    Evaluation = 'evaluation'
+    Disabled = 'none'
+    Enabled = 'full'
+
+    def __str__(self) -> str:
+        return self.value
+
 def extract_keys(env):
     keys = []
     #print(env.grid)
@@ -28,7 +39,7 @@ def extract_keys(env):
     return keys
 
 def create_log_dir(args):
-    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}-env:{args.env}"
+    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}"
 
 def get_action_index_mapping(actions):
@@ -77,12 +88,12 @@ def parse_arguments(argparse):
     parser.add_argument("--grid_to_prism_binary_path", default="./main")
     parser.add_argument("--grid_path", default="grid")
     parser.add_argument("--prism_path", default="grid")
-    parser.add_argument("--no_masking", default=False)
     parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
     parser.add_argument("--log_dir", default="../log_results/")
    parser.add_argument("--iterations", type=int, default=30 )
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]")  # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)
 
     args = parser.parse_args()
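
The `--shielding` option replaces the old `--no_masking` boolean. Because the enum values are strings and `__str__` returns the value, the class itself can serve as the argparse converter; a quick sketch of the resulting behaviour (the parser line mirrors the one added above):

    import argparse
    from helpers import ShieldingConfig

    parser = argparse.ArgumentParser()
    # argparse calls ShieldingConfig("training"), an Enum lookup by value
    parser.add_argument("--shielding", type=ShieldingConfig,
                        choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)

    args = parser.parse_args(["--shielding", "training"])
    assert args.shielding is ShieldingConfig.Training
    print(args.shielding)  # prints 'training' via the __str__ override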
