
added tune example

refactoring; added evaluation logging
Thomas Knoll, 1 year ago
commit 138d917fd6
  1. examples/shields/rl/11_minigridrl.py (48 changed lines)
  2. examples/shields/rl/13_minigridsb.py (9 changed lines)
  3. examples/shields/rl/14_train_eval.py (44 changed lines)
  4. examples/shields/rl/15_train_eval_tune.py (118 changed lines)
  5. examples/shields/rl/ShieldHandlers.py (21 changed lines)
  6. examples/shields/rl/Wrappers.py (75 changed lines)
  7. examples/shields/rl/callbacks.py (61 changed lines)
  8. examples/shields/rl/helpers.py (6 changed lines)

examples/shields/rl/11_minigridrl.py (48 changed lines)

@@ -1,10 +1,4 @@
from typing import Dict
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
import gymnasium as gym
@@ -15,7 +9,6 @@ import minigrid
from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog
@@ -23,42 +16,13 @@ from ray.rllib.models import ModelCatalog
from TorchActionMaskModel import TorchActionMaskModel
from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from ShieldHandlers import MiniGridShieldHandler
from ShieldHandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
import matplotlib.pyplot as plt
from ray.tune.logger import TBXLogger
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
# print("On episode start print")
# print(env.printGrid())
# print(worker)
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
#print("On episode end print")
#print(env.printGrid())
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
@@ -69,7 +33,7 @@ def shielding_env_creater(config):
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query)
# env = minigrid.wrappers.ImgObsWrapper(env)
# env = ImgObsWrapper(env)
env = OneHotShieldingWrapper(env,
@@ -98,7 +62,7 @@ def ppo(args):
config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
.framework("torch")
.callbacks(MyCallbacks)
.rl_module(_enable_rl_module_api = False)
@@ -132,7 +96,7 @@ def dqn(args):
config = config.rollouts(num_rollout_workers=args.workers)
config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args })
config = config.framework("torch")
#config = config.callbacks(MyCallbacks)
config = config.callbacks(MyCallbacks)
config = config.rl_module(_enable_rl_module_api = False)
config = config.debugging(logger_config={
"type": TBXLogger,

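Aside: the shielding flag passed through env_config above is derived from the CLI enum. A minimal sketch of that predicate, factored into helpers whose names are ours, not the repo's, assuming ShieldingConfig from helpers.py:

from helpers import ShieldingConfig

def shield_in_training(mode: ShieldingConfig) -> bool:
    # Mask actions during training when shielding is Full or Training-only.
    return mode is ShieldingConfig.Full or mode is ShieldingConfig.Training

def shield_in_evaluation(mode: ShieldingConfig) -> bool:
    # Mask actions during evaluation when shielding is Full or Evaluation-only.
    return mode is ShieldingConfig.Full or mode is ShieldingConfig.Evaluation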
examples/shields/rl/13_minigridsb.py (9 changed lines)

@@ -11,8 +11,8 @@ from minigrid.core.actions import Actions
import numpy as np
import time
from helpers import parse_arguments, extract_keys, get_action_index_mapping, create_log_dir
from ShieldHandlers import MiniGridShieldHandler
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from ShieldHandlers import MiniGridShieldHandler, create_shield_query
from Wrappers import MiniGridSbShieldingWrapper
class CustomCallback(BaseCallback):
@@ -27,6 +27,7 @@ class CustomCallback(BaseCallback):
def mask_fn(env: gym.Env):
return env.create_action_mask()
@@ -42,10 +43,10 @@ def main():
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(args.env, render_mode="rgb_array")
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, no_masking=args.no_masking)
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
env = ActionMasker(env, mask_fn)
callback = CustomCallback(1, env)
model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args))
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
iterations = args.iterations

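For context, the sb3-contrib wiring these changed lines feed into looks roughly like the sketch below; the grid/prism paths and the timestep budget are illustrative, everything else follows the imports above.

import gymnasium as gym
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from ShieldHandlers import MiniGridShieldHandler, create_shield_query
from Wrappers import MiniGridSbShieldingWrapper

def mask_fn(env: gym.Env):
    # Delegates to MiniGridSbShieldingWrapper.create_action_mask()
    return env.create_action_mask()

shield_creator = MiniGridShieldHandler("grid.txt", "./grid-to-prism",
                                       "grid.prism",
                                       'Pmax=? [G !"AgentIsInLavaAndNotDone"]')
env = gym.make("MiniGrid-LavaCrossingS9N1-v0", render_mode="rgb_array")
env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator,
                                 shield_query_creator=create_shield_query,
                                 mask_actions=True)
env = ActionMasker(env, mask_fn)
model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1)
model.learn(total_timesteps=10_000)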
examples/shields/rl/14_train_eval.py (44 changed lines)

@@ -9,18 +9,20 @@ from ray.tune import register_env
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
# from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print
from ray.tune.logger import pretty_print, TBXLogger, TBXLoggerCallback, DEFAULT_LOGGERS, UnifiedLogger, CSVLogger
from ray.rllib.models import ModelCatalog
from TorchActionMaskModel import TorchActionMaskModel
from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from ShieldHandlers import MiniGridShieldHandler
from ShieldHandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from ray.tune.logger import TBXLogger
@@ -39,7 +41,7 @@ def shielding_env_creater(config):
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
@@ -67,16 +69,18 @@ def ppo(args):
.rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.environment( env="mini-grid-shielding",
env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training})
env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
.framework("torch")
.evaluation(evaluation_config={ "evaluation_interval": 1,
"evaluation_parallel_to_training": False,
.callbacks(MyCallbacks)
.evaluation(evaluation_config={
"evaluation_interval": 1,
"evaluation_duration": 10,
"evaluation_num_workers":1,
"env": "mini-grid-shielding",
"env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Evaluation}})
#.callbacks(MyCallbacks)
"env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
.rl_module(_enable_rl_module_api = False)
.debugging(logger_config={
"type": TBXLogger,
"type": UnifiedLogger,
"logdir": create_log_dir(args)
})
.training(_enable_learner_api=False, model={
@@ -90,17 +94,35 @@ def ppo(args):
iterations = args.iterations
for i in range(iterations):
algo.train()
if i % 5 == 0:
algo.save()
writer = SummaryWriter(log_dir=F"{create_log_dir(args)}-eval")
# csv_logger = CSVLogger()  # unused; CSVLogger requires a config and logdir
for i in range(iterations):
eval_result = algo.evaluate()
print(pretty_print(eval_result))
print(eval_result)
# logger.on_result(eval_result)
evaluation = eval_result['evaluation']
episode_reward_mean = evaluation['episode_reward_mean']
episode_len_mean = evaluation['episode_len_mean']
print(episode_reward_mean)
writer.add_scalar("evaluation/episode_reward_mean", episode_reward_mean, i)
writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i)
writer.close()
def main():
import argparse

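The new SummaryWriter loop above evaluates only after all training iterations finish. If per-iteration curves are wanted instead, the same pattern can be interleaved with training; a sketch under that assumption (with algo, args, and create_log_dir in scope as above), not what the file itself does:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir=F"{create_log_dir(args)}-eval")
for i in range(args.iterations):
    algo.train()                     # one training iteration
    eval_result = algo.evaluate()    # returns {"evaluation": {...}}
    metrics = eval_result["evaluation"]
    writer.add_scalar("evaluation/episode_reward_mean",
                      metrics["episode_reward_mean"], i)
    writer.add_scalar("evaluation/episode_len_mean",
                      metrics["episode_len_mean"], i)
writer.close()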
examples/shields/rl/15_train_eval_tune.py (118 changed lines)

@@ -0,0 +1,118 @@
import gymnasium as gym
import minigrid
# import numpy as np
# import ray
from ray.tune import register_env
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.dqn.dqn import DQNConfig
# from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.tune.logger import pretty_print, TBXLogger, TBXLoggerCallback, DEFAULT_LOGGERS, UnifiedLogger
from ray.rllib.models import ModelCatalog
from TorchActionMaskModel import TorchActionMaskModel
from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper
from helpers import parse_arguments, create_log_dir, ShieldingConfig
from ShieldHandlers import MiniGridShieldHandler, create_shield_query
from callbacks import MyCallbacks
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
def shielding_env_creater(config):
name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0")
framestack = config.get("framestack", 4)
args = config.get("args", None)
args.grid_path = F"{args.grid_path}_{config.worker_index}.txt"
args.prism_path = F"{args.prism_path}_{config.worker_index}.prism"
shielding = config.get("shielding", False)
# if shielding:
# assert(False)
shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
env = gym.make(name)
env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=shielding)
env = OneHotShieldingWrapper(env,
config.vector_index if hasattr(config, "vector_index") else 0,
framestack=framestack
)
return env
def register_minigrid_shielding_env(args):
env_name = "mini-grid-shielding"
register_env(env_name, shielding_env_creater)
ModelCatalog.register_custom_model(
"shielding_model",
TorchActionMaskModel
)
def ppo(args):
register_minigrid_shielding_env(args)
config = (PPOConfig()
.rollouts(num_rollout_workers=args.workers)
.resources(num_gpus=0)
.environment( env="mini-grid-shielding",
env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training})
.framework("torch")
.callbacks(MyCallbacks)
.evaluation(evaluation_config={
"evaluation_interval": 1,
"evaluation_duration": 10,
"evaluation_num_workers":1,
"env": "mini-grid-shielding",
"env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}})
.rl_module(_enable_rl_module_api = False)
.debugging(logger_config={
"type": UnifiedLogger,
"logdir": create_log_dir(args)
})
.training(_enable_learner_api=False, model={
"custom_model": "shielding_model"
}))
tuner = tune.Tuner("PPO",
run_config=air.RunConfig(
stop = {"episode_reward_mean": 50},
checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
storage_path=F"{create_log_dir(args)}-tuner"
),
param_space=config,)
tuner.fit()
# iterations = args.iterations  # unused after switching to Tuner
print(config.to_dict())
# tune.run("PPO", config=config)  # redundant: tuner.fit() above already runs the trials
# print(episode_reward_mean)
# writer.add_scalar("evaluation/episode_reward", episode_reward_mean, i)
def main():
import argparse
args = parse_arguments(argparse)
ppo(args)
if __name__ == '__main__':
main()

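tuner.fit() returns a ResultGrid, so the run above can be followed up by pulling out the best trial; a short sketch using standard Ray Tune API (the metric name matches the stop criterion in the RunConfig):

results = tuner.fit()
best = results.get_best_result(metric="episode_reward_mean", mode="max")
print(best.metrics["episode_reward_mean"])  # final mean reward of the best trial
print(best.checkpoint)                      # checkpoint saved at the end of that trial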
examples/shields/rl/ShieldHandlers.py (21 changed lines)

@@ -15,7 +9,7 @@ import os
class ShieldHandler(ABC):
def __init__(self) -> None:
pass
def create_shield(self, **kwargs):
def create_shield(self, **kwargs) -> dict:
pass
class MiniGridShieldHandler(ShieldHandler):
@@ -32,7 +34,9 @@ class MiniGridShieldHandler(ShieldHandler):
def __create_prism(self):
os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}")
result = os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}")
assert result == 0, "Prism file could not be generated"
f = open(self.prism_path, "a")
f.write("label \"AgentIsInLava\" = AgentIsInLava;")
@@ -78,4 +80,17 @@ class MiniGridShieldHandler(ShieldHandler):
self.__create_prism()
return self.__create_shield_dict()
def create_shield_query(env):
coordinates = env.env.agent_pos
view_direction = env.env.agent_dir
key_text = ""
# only support one key for now
# print(F"Agent pos is {env.env.agent_pos} and direction {env.env.agent_dir}")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
return cur_pos_str

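Note that create_shield_query drops the key handling the removed wrapper code had. A hypothetical key-aware variant, reconstructed from the deleted lines in Wrappers.py (the function name is ours; single-key environments only):

def create_shield_query_with_key(env):
    # Same query as above, but prefixed with the key label when the agent
    # carries a key, mirroring the logic removed from the wrappers.
    coordinates = env.env.agent_pos
    view_direction = env.env.agent_dir
    key_text = ""
    if env.env.carrying and env.env.carrying.type == "key":
        key_text = F"Agent_has_{env.env.carrying.color}_key\t& "
    return F"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"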
examples/shields/rl/Wrappers.py (75 changed lines)

@@ -82,7 +82,12 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper):
class MiniGridShieldingWrapper(gym.core.Wrapper):
def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True, mask_actions=True):
def __init__(self,
env,
shield_creator : ShieldHandler,
shield_query_creator,
create_shield_at_reset=True,
mask_actions=True):
super(MiniGridShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = Dict(
@@ -95,32 +100,18 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
self.create_shield_at_reset = create_shield_at_reset
self.shield = shield_creator.create_shield(env=self.env)
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
def create_action_mask(self):
if not self.mask_actions:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
allowed_actions = []
cur_pos_str = self.shield_query_creator(self.env)
# Create the mask
# If shield restricts actions, mask only valid actions with 1.0
# else set all actions valid
allowed_actions = []
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
@@ -144,7 +135,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
def reset(self, *, seed=None, options=None):
@@ -175,38 +166,32 @@ class MiniGridShieldingWrapper(gym.core.Wrapper):
class MiniGridSbShieldingWrapper(gym.core.Wrapper):
def __init__(self, env, shield_creator : ShieldHandler, no_masking=False):
def __init__(self,
env,
shield_creator : ShieldHandler,
shield_query_creator,
create_shield_at_reset = True,
mask_actions=True,
):
super(MiniGridSbShieldingWrapper, self).__init__(env)
self.max_available_actions = env.action_space.n
self.observation_space = env.observation_space.spaces["image"]
self.shield_creator = shield_creator
self.no_masking = no_masking
self.mask_actions = mask_actions
self.shield_query_creator = shield_query_creator
def create_action_mask(self):
if self.no_masking:
if not self.mask_actions:
return np.array([1.0] * self.max_available_actions, dtype=np.int8)
coordinates = self.env.agent_pos
view_direction = self.env.agent_dir
key_text = ""
# only support one key for now
if self.keys:
key_text = F"!Agent_has_{self.keys[0]}_key\t& "
if self.env.carrying and self.env.carrying.type == "key":
key_text = F"Agent_has_{self.env.carrying.color}_key\t& "
#print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ")
cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]"
cur_pos_str = self.shield_query_creator(self.env)
allowed_actions = []
# Create the mask
# If shield restricts action mask only valid with 1.0
# else set all actions as valid
# If shield restricts actions, mask only valid actions with 1.0
# else set all actions valid
mask = np.array([0.0] * self.max_available_actions, dtype=np.int8)
if cur_pos_str in self.shield and self.shield[cur_pos_str]:
@@ -215,24 +200,20 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
index = get_action_index_mapping(allowed_action[1])
if index is None:
assert(False)
mask[index] = 1.0
else:
# print(F"Not in shield {cur_pos_str}")
for index, x in enumerate(mask):
mask[index] = 1.0
front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1])
# if front_tile is not None and front_tile.type == "key":
# mask[Actions.pickup] = 1.0
# if self.env.carrying:
# mask[Actions.drop] = 1.0
if front_tile and front_tile.type == "door":
mask[Actions.toggle] = 1.0
return mask
return mask
def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset(seed=seed, options=options)
@@ -245,7 +226,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper):
return obs["image"], infos
def step(self, action):
orig_obs, rew, done, truncated, info = self.env.step(action)
orig_obs, rew, done, truncated, info = self.env.step(action)
obs = orig_obs["image"]
return obs, rew, done, truncated, info

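With the query creator injected, swapping query formats no longer requires touching the wrapper. A usage sketch (paths and environment name are illustrative):

import gymnasium as gym
from ShieldHandlers import MiniGridShieldHandler, create_shield_query
from Wrappers import MiniGridShieldingWrapper

shield_creator = MiniGridShieldHandler("grid.txt", "./grid-to-prism",
                                       "grid.prism",
                                       'Pmax=? [G !"AgentIsInLavaAndNotDone"]')
env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
env = MiniGridShieldingWrapper(env,
                               shield_creator=shield_creator,
                               shield_query_creator=create_shield_query,
                               mask_actions=True)
obs, info = env.reset()
mask = env.create_action_mask()  # 1.0 marks actions the shield allows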
examples/shields/rl/callbacks.py (61 changed lines)

@@ -0,0 +1,61 @@
from typing import Dict
from ray.rllib.policy import Policy
from ray.rllib.utils.typing import PolicyID
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.env.base_env import BaseEnv
from ray.rllib.evaluation import RolloutWorker
from ray.rllib.evaluation.episode import Episode
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks
class MyCallbacks(DefaultCallbacks):
def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode started Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
episode.user_data["count"] = 0
episode.user_data["ran_into_lava"] = []
episode.user_data["goals_reached"] = []
episode.hist_data["ran_into_lava"] = []
episode.hist_data["goals_reached"] = []
# print("On episode start print")
# print(env.printGrid())
# print(worker)
# print(env.action_space.n)
# print(env.actions)
# print(env.mission)
# print(env.observation_space)
# img = env.get_frame()
# plt.imshow(img)
# plt.show()
def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None:
episode.user_data["count"] = episode.user_data["count"] + 1
env = base_env.get_sub_environments()[0]
# print(env.printGrid())
def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None:
# print(F"Episode end Environment: {base_env.get_sub_environments()}")
env = base_env.get_sub_environments()[0]
agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1])
episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal")
episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava")
episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal"
episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava"
#print("On episode end print")
#print(env.printGrid())
episode.hist_data["goals_reached"] = episode.user_data["goals_reached"]
episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"]
def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None:
print("Evaluate Start")
def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None:
print("Evaluate End")

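The booleans recorded in custom_metrics are aggregated by RLlib over each iteration's episodes and surface in the training results; roughly, assuming a built algo:

result = algo.train()
# RLlib reports mean/min/max per custom metric and iteration.
print(result["custom_metrics"]["ran_into_lava_mean"])
print(result["custom_metrics"]["reached_goal_mean"])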
examples/shields/rl/helpers.py (6 changed lines)

@@ -20,7 +20,7 @@ class ShieldingConfig(Enum):
Training = 'training'
Evaluation = 'evaluation'
Disabled = 'none'
Enabled = 'full'
Full = 'full'
def __str__(self) -> str:
return self.value
@@ -39,7 +39,7 @@ def extract_keys(env):
return keys
def create_log_dir(args):
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}"
return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}-iterations:{args.iterations}"
def get_action_index_mapping(actions):
@@ -93,7 +93,7 @@ def parse_arguments(argparse):
parser.add_argument("--iterations", type=int, default=30 )
parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]"
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled)
parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full)
args = parser.parse_args()

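Since ShieldingConfig defines __str__ as its value, the renamed Full member round-trips through argparse; a quick sketch:

from helpers import ShieldingConfig

mode = ShieldingConfig("full")   # what argparse does with --shielding full
assert mode is ShieldingConfig.Full
print(mode)                      # prints "full" via __str__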