
added more shielding options (training, evaluation, none, both)
refactoring

Thomas Knoll, 2 years ago
commit e2c855dc6a
5 changed files:

  1. examples/shields/rl/11_minigridrl.py (87)
  2. examples/shields/rl/14_train_eval.py (114)
  3. examples/shields/rl/TorchActionMaskModel.py (7)
  4. examples/shields/rl/Wrappers.py (8)
  5. examples/shields/rl/helpers.py (17)
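
The four shielding modes map onto two independent booleans: shield during training and shield during evaluation. A minimal sketch of the mapping, reusing the ShieldingConfig enum that this commit adds to helpers.py:

    from enum import Enum

    class ShieldingConfig(Enum):  # as added to helpers.py in this commit
        Training = 'training'
        Evaluation = 'evaluation'
        Disabled = 'none'
        Enabled = 'full'

    # Enabled shields both phases, Disabled neither, the other two exactly one.
    for mode in ShieldingConfig:
        in_train = mode in (ShieldingConfig.Enabled, ShieldingConfig.Training)
        in_eval = mode in (ShieldingConfig.Enabled, ShieldingConfig.Evaluation)
        print(f"{mode.value}: shield in training={in_train}, shield in evaluation={in_eval}")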

examples/shields/rl/11_minigridrl.py (87 changed lines)

@@ -1,10 +1,10 @@
-# from typing import Dict
-# from ray.rllib.env.base_env import BaseEnv
-# from ray.rllib.evaluation import RolloutWorker
-# from ray.rllib.evaluation.episode import Episode
-# from ray.rllib.evaluation.episode_v2 import EpisodeV2
-# from ray.rllib.policy import Policy
-# from ray.rllib.utils.typing import PolicyID
+from typing import Dict
+from ray.rllib.env.base_env import BaseEnv
+from ray.rllib.evaluation import RolloutWorker
+from ray.rllib.evaluation.episode import Episode
+from ray.rllib.evaluation.episode_v2 import EpisodeV2
+from ray.rllib.policy import Policy
+from ray.rllib.utils.typing import PolicyID
 import gymnasium as gym
@@ -15,47 +15,47 @@ import minigrid
 from ray.tune import register_env
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.dqn.dqn import DQNConfig
-# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.tune.logger import pretty_print
 from ray.rllib.models import ModelCatalog
 from TorchActionMaskModel import TorchActionMaskModel
 from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
-from helpers import parse_arguments, create_log_dir
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
 from ShieldHandlers import MiniGridShieldHandler
 import matplotlib.pyplot as plt
 from ray.tune.logger import TBXLogger

-# class MyCallbacks(DefaultCallbacks):
-#     def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         episode.user_data["count"] = 0
-#         # print("On episode start print")
-#         # print(env.printGrid())
-#         # print(worker)
-#         # print(env.action_space.n)
-#         # print(env.actions)
-#         # print(env.mission)
-#         # print(env.observation_space)
-#         # img = env.get_frame()
-#         # plt.imshow(img)
-#         # plt.show()
+class MyCallbacks(DefaultCallbacks):
+    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode started Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        episode.user_data["count"] = 0
+        # print("On episode start print")
+        # print(env.printGrid())
+        # print(worker)
+        # print(env.action_space.n)
+        # print(env.actions)
+        # print(env.mission)
+        # print(env.observation_space)
+        # img = env.get_frame()
+        # plt.imshow(img)
+        # plt.show()

-#     def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
-#         episode.user_data["count"] = episode.user_data["count"] + 1
-#         env = base_env.get_sub_environments()[0]
-#         # print(env.printGrid())
+    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
+        episode.user_data["count"] = episode.user_data["count"] + 1
+        env = base_env.get_sub_environments()[0]
+        # print(env.printGrid())

-#     def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
-#         # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
-#         env = base_env.get_sub_environments()[0]
-#         #print("On episode end print")
-#         #print(env.printGrid())
+    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
+        # print(F"Epsiode end Environment: {base_env.get_sub_environments()}")
+        env = base_env.get_sub_environments()[0]
+        #print("On episode end print")
+        #print(env.printGrid())
@@ -83,7 +83,7 @@ def shielding_env_creater(config):
 def register_minigrid_shielding_env(args):
-    env_name = "mini-grid"
+    env_name = "mini-grid-shielding"
     register_env(env_name, shielding_env_creater)

     ModelCatalog.register_custom_model(
@@ -98,25 +98,21 @@ def ppo(args):
     config = (PPOConfig()
         .rollouts(num_rollout_workers=args.workers)
         .resources(num_gpus=0)
-        .environment(env="mini-grid", env_config={"name": args.env, "args": args})
+        .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
         .framework("torch")
-        #.callbacks(MyCallbacks)
+        .callbacks(MyCallbacks)
         .rl_module(_enable_rl_module_api = False)
         .debugging(logger_config={
            "type": TBXLogger,
            "logdir": create_log_dir(args)
         })
         .training(_enable_learner_api=False ,model={
-            "custom_model": "shielding_model",
-            "custom_model_config" : {"no_masking": args.no_masking}
+            "custom_model": "shielding_model"
         }))
-    algo =(
+    algo =(
         config.build()
-    )
-    algo.eva
+    )

     for i in range(args.iterations):
         result = algo.train()
@@ -134,7 +130,7 @@ def dqn(args):
     config = DQNConfig()
     config = config.resources(num_gpus=0)
     config = config.rollouts(num_rollout_workers=args.workers)
-    config = config.environment(env="mini-grid", env_config={"name": args.env, "args": args })
+    config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
     config = config.framework("torch")
     #config = config.callbacks(MyCallbacks)
     config = config.rl_module(_enable_rl_module_api = False)
@ -143,8 +139,7 @@ def dqn(args):
"logdir": create_log_dir(args)
})
config = config.training(hiddens=[], dueling=False, model={
"custom_model": "shielding_model",
"custom_model_config" : {"no_masking": args.no_masking}
"custom_model": "shielding_model"
})
algo = (
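
The masking switch moves out of the model's custom_model_config and into the env_config as a precomputed boolean. A sketch of that expression factored into a named helper (the helper name is hypothetical; the comparisons are exactly the ones passed to .environment() in ppo(), and it assumes helpers.py and its stormpy dependencies are importable):

    from helpers import ShieldingConfig

    def shield_during_training(shielding: ShieldingConfig) -> bool:
        # Mirrors the inline expression in ppo()'s env_config above.
        return (shielding is ShieldingConfig.Enabled
                or shielding is ShieldingConfig.Training)

    assert shield_during_training(ShieldingConfig.Training)
    assert not shield_during_training(ShieldingConfig.Evaluation)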

examples/shields/rl/14_train_eval.py (114 changed lines, new file)

@@ -0,0 +1,114 @@
+import gymnasium as gym
+import minigrid
+# import numpy as np
+# import ray
+from ray.tune import register_env
+from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.dqn.dqn import DQNConfig
+# from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.tune.logger import pretty_print
+from ray.rllib.models import ModelCatalog
+from TorchActionMaskModel import TorchActionMaskModel
+from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
+from helpers import parse_arguments, create_log_dir, ShieldingConfig
+from ShieldHandlers import MiniGridShieldHandler
+import matplotlib.pyplot as plt
+from ray.tune.logger import TBXLogger
+
+def shielding_env_creater(config):
+    name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
+    framestack = config.get("framestack", 4)
+    args = config.get("args", None)
+    args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
+    args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
+    shielding = config.get("shielding", False)
+    # if shielding:
+    #     assert(False)
+
+    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
+
+    env = gym.make(name)
+    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding)
+    env = OneHotShieldingWrapper(env,
+                                 config.vector_index if hasattr(config, "vector_index") else 0,
+                                 framestack=framestack
+                                 )
+    return env
+
+def register_minigrid_shielding_env(args):
+    env_name = "mini-grid-shielding"
+    register_env(env_name, shielding_env_creater)
+
+    ModelCatalog.register_custom_model(
+        "shielding_model",
+        TorchActionMaskModel
+    )
+
+def ppo(args):
+    register_minigrid_shielding_env(args)
+
+    config = (PPOConfig()
+        .rollouts(num_rollout_workers=args.workers)
+        .resources(num_gpus=0)
+        .environment( env="mini-grid-shielding",
+                      env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
+        .framework("torch")
+        .evaluation(evaluation_config={ "evaluation_interval": 1,
+                                        "evaluation_parallel_to_training": False,
+                                        "env": "mini-grid-shielding",
+                                        "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Evaluation}})
+        #.callbacks(MyCallbacks)
+        .rl_module(_enable_rl_module_api = False)
+        .debugging(logger_config={
+            "type": TBXLogger,
+            "logdir": create_log_dir(args)
+        })
+        .training(_enable_learner_api=False ,model={
+            "custom_model": "shielding_model"
+        }))
+
+    algo =(
+        config.build()
+    )
+
+    iterations = args.iterations
+
+    for i in range(iterations):
+        algo.train()
+
+        if i % 5 == 0:
+            algo.save()
+
+    for i in range(iterations):
+        eval_result = algo.evaluate()
+        print(pretty_print(eval_result))
+
+def main():
+    import argparse
+    args = parse_arguments(argparse)
+
+    ppo(args)
+
+if __name__ == '__main__':
+    main()
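
One caveat in shielding_env_creater(): it rewrites args.grid_path and args.prism_path in place, so invoking the creator twice for the same worker would append the worker suffix twice. A hypothetical side-effect-free variant of the path handling:

    # Derive per-worker paths instead of mutating the shared args namespace.
    def worker_paths(grid_path: str, prism_path: str, worker_index: int) -> tuple[str, str]:
        return (F"{grid_path}_{worker_index}.txt", F"{prism_path}_{worker_index}.prism")

    print(worker_paths("grid", "grid", 1))  # ('grid_1.txt', 'grid_1.prism')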

examples/shields/rl/TorchActionMaskModel.py (7 changed lines)

@@ -38,9 +38,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
             name + "_internal",
         )

-        self.no_masking = False
-        if "no_masking" in model_config["custom_model_config"]:
-            self.no_masking = model_config["custom_model_config"]["no_masking"]

     def forward(self, input_dict, state, seq_lens):
         # Extract the available actions tensor from the observation.
@@ -48,10 +45,6 @@ class TorchActionMaskModel(TorchModelV2, nn.Module):
         logits, _ = self.internal_model({"obs": input_dict["obs"]["data"]})

         action_mask = input_dict["obs"]["action_mask"]

-        # If action masking is disabled, directly return unmasked logits
-        if self.no_masking:
-            return logits, state

         inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
         masked_logits = logits + inf_mask
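
The no_masking branch can go because disabling shielding is now handled upstream: the wrapper emits an all-ones mask, log(1) is 0, and adding zeros leaves the logits unchanged. A standalone sketch of that identity (plain torch, with FLOAT_MIN replaced by an explicit float32 minimum):

    import torch

    logits = torch.tensor([[0.5, -1.2, 2.0]])
    mask = torch.ones(1, 3)                                # "no shielding": all actions allowed
    inf_mask = torch.clamp(torch.log(mask), min=-3.4e38)   # log(1) == 0, so clamp is a no-op
    assert torch.equal(logits + inf_mask, logits)          # unmasked logits pass through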

examples/shields/rl/Wrappers.py (8 changed lines)

@@ -82,7 +82,7 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
 class MiniGridShieldingWrapper(gym.core.Wrapper):
-    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True):
+    def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True, mask_actions=True):
         super(MiniGridShieldingWrapper, self).__init__(env)
         self.max_available_actions = env.action_space.n
         self.observation_space = Dict(
@@ -94,8 +94,12 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
         self.shield_creator = shield_creator
         self.create_shield_at_reset = create_shield_at_reset
         self.shield = shield_creator.create_shield(env=self.env)
+        self.mask_actions = mask_actions

     def create_action_mask(self):
+        if not self.mask_actions:
+            return np.array([1.0] * self.max_available_actions, dtype=np.int8)
+
         coordinates = self.env.agent_pos
         view_direction = self.env.agent_dir
@@ -146,7 +150,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
     def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset(seed=seed, options=options)
-        if self.create_shield_at_reset:
+        if self.create_shield_at_reset and self.mask_actions:
             self.shield = self.shield_creator.create_shield(env=self.env)
         self.keys = extract_keys(self.env)
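
With mask_actions=False the same wrapper serves unshielded runs: create_action_mask() returns all ones and reset() skips rebuilding the shield. A hypothetical construction mirroring shielding_env_creater() (paths and formula are placeholders, and the example modules plus their stormpy dependencies must be importable):

    import gymnasium as gym
    from ShieldHandlers import MiniGridShieldHandler
    from Wrappers import MiniGridShieldingWrapper, OneHotShieldingWrapper

    shield_creator = MiniGridShieldHandler("grid_0.txt", "./main", "grid_0.prism",
                                           'Pmax=? [G !"AgentIsInLavaAndNotDone"]')
    env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
    env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=False)
    env = OneHotShieldingWrapper(env, 0, framestack=4)  # vector_index 0, as in the creator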

examples/shields/rl/helpers.py (17 changed lines)

@@ -2,6 +2,9 @@ import minigrid
 from minigrid.core.actions import Actions
 from datetime import datetime
+from enum import Enum
 import os

 import stormpy
 import stormpy.core
@@ -13,8 +16,16 @@ import stormpy.logic
 import stormpy.examples
 import stormpy.examples.files

+class ShieldingConfig(Enum):
+    Training = 'training'
+    Evaluation = 'evaluation'
+    Disabled = 'none'
+    Enabled = 'full'
+
+    def __str__(self) -> str:
+        return self.value

 def extract_keys(env):
     keys = []
     #print(env.grid)
@@ -28,7 +39,7 @@ def extract_keys(env):
     return keys

 def create_log_dir(args):
-    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-masking:{not args.no_masking}-env:{args.env}"
+    return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}"

 def get_action_index_mapping(actions):
@@ -77,12 +88,12 @@ def parse_arguments(argparse):
     parser.add_argument("--grid_to_prism_binary_path", default="./main")
     parser.add_argument("--grid_path", default="grid")
     parser.add_argument("--prism_path", default="grid")
-    parser.add_argument("--no_masking", default=False)
     parser.add_argument("--algorithm", default="ppo", choices=["ppo", "dqn"])
     parser.add_argument("--log_dir", default="../log_results/")
     parser.add_argument("--iterations", type=int, default=30 )
     parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
     parser.add_argument("--workers", type=int, default=1)
+    parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)

     args = parser.parse_args()
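
Because each enum member's value is its flag string and __str__ returns it, argparse can both convert and display the choices directly. A self-contained sketch of how the new --shielding flag parses:

    import argparse
    from enum import Enum

    class ShieldingConfig(Enum):  # copy of the enum added to helpers.py above
        Training = 'training'
        Evaluation = 'evaluation'
        Disabled = 'none'
        Enabled = 'full'

        def __str__(self) -> str:
            return self.value

    parser = argparse.ArgumentParser()
    parser.add_argument("--shielding", type=ShieldingConfig,
                        choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)
    args = parser.parse_args(["--shielding", "training"])
    assert args.shielding is ShieldingConfig.Training  # the value string maps to the member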
