From 138d917fd6684a690fc2e6fd8b5afc70518617c2 Mon Sep 17 00:00:00 2001 From: Thomas Knoll Date: Thu, 7 Sep 2023 15:31:46 +0200 Subject: [PATCH] added tune example refactored and evaluation logging --- examples/shields/rl/11_minigridrl.py | 48 ++------- examples/shields/rl/13_minigridsb.py | 9 +- examples/shields/rl/14_train_eval.py | 44 ++++++-- examples/shields/rl/15_train_eval_tune.py | 118 ++++++++++++++++++++++ examples/shields/rl/ShieldHandlers.py | 21 +++- examples/shields/rl/Wrappers.py | 75 +++++--------- examples/shields/rl/callbacks.py | 61 +++++++++++ examples/shields/rl/helpers.py | 6 +- 8 files changed, 272 insertions(+), 110 deletions(-) create mode 100644 examples/shields/rl/15_train_eval_tune.py create mode 100644 examples/shields/rl/callbacks.py diff --git a/examples/shields/rl/11_minigridrl.py b/examples/shields/rl/11_minigridrl.py index 9f25be6..d792c06 100644 --- a/examples/shields/rl/11_minigridrl.py +++ b/examples/shields/rl/11_minigridrl.py @@ -1,10 +1,4 @@ -from typing import Dict -from ray.rllib.env.base_env import BaseEnv -from ray.rllib.evaluation import RolloutWorker -from ray.rllib.evaluation.episode import Episode -from ray.rllib.evaluation.episode_v2 import EpisodeV2 -from ray.rllib.policy import Policy -from ray.rllib.utils.typing import PolicyID + import gymnasium as gym @@ -15,7 +9,6 @@ import minigrid from ray.tune import register_env from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.dqn.dqn import DQNConfig -from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.tune.logger import pretty_print from ray.rllib.models import ModelCatalog @@ -23,42 +16,13 @@ from ray.rllib.models import ModelCatalog from TorchActionMaskModel import TorchActionMaskModel from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from helpers import parse_arguments, create_log_dir, ShieldingConfig -from ShieldHandlers import MiniGridShieldHandler +from ShieldHandlers import MiniGridShieldHandler, create_shield_query +from callbacks import MyCallbacks import matplotlib.pyplot as plt from ray.tune.logger import TBXLogger -class MyCallbacks(DefaultCallbacks): - def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: - # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") - env = base_env.get_sub_environments()[0] - episode.user_data["count"] = 0 - # print("On episode start print") - # print(env.printGrid()) - # print(worker) - # print(env.action_space.n) - # print(env.actions) - # print(env.mission) - # print(env.observation_space) - # img = env.get_frame() - # plt.imshow(img) - # plt.show() - - - def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: - episode.user_data["count"] = episode.user_data["count"] + 1 - env = base_env.get_sub_environments()[0] - # print(env.printGrid()) - - def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: - # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") - env = base_env.get_sub_environments()[0] - #print("On episode end print") - #print(env.printGrid()) - - - def shielding_env_creater(config): name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") 
framestack = config.get("framestack", 4) @@ -69,7 +33,7 @@ def shielding_env_creater(config): shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query) # env = minigrid.wrappers.ImgObsWrapper(env) # env = ImgObsWrapper(env) env = OneHotShieldingWrapper(env, @@ -98,7 +62,7 @@ def ppo(args): config = (PPOConfig() .rollouts(num_rollout_workers=args.workers) .resources(num_gpus=0) - .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training}) + .environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .framework("torch") .callbacks(MyCallbacks) .rl_module(_enable_rl_module_api = False) @@ -132,7 +96,7 @@ def dqn(args): config = config.rollouts(num_rollout_workers=args.workers) config = config.environment(env="mini-grid-shielding", env_config={"name": args.env, "args": args }) config = config.framework("torch") - #config = config.callbacks(MyCallbacks) + config = config.callbacks(MyCallbacks) config = config.rl_module(_enable_rl_module_api = False) config = config.debugging(logger_config={ "type": TBXLogger, diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 8d0c86f..05b2104 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -11,8 +11,8 @@ from minigrid.core.actions import Actions import numpy as np import time -from helpers import parse_arguments, extract_keys, get_action_index_mapping, create_log_dir -from ShieldHandlers import MiniGridShieldHandler +from helpers import parse_arguments, create_log_dir, ShieldingConfig +from ShieldHandlers import MiniGridShieldHandler, create_shield_query from Wrappers import MiniGridSbShieldingWrapper class CustomCallback(BaseCallback): @@ -27,6 +27,7 @@ class CustomCallback(BaseCallback): + def mask_fn(env: gym.Env): return env.create_action_mask() @@ -42,10 +43,10 @@ def main(): shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) env = gym.make(args.env, render_mode="rgb_array") - env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, no_masking=args.no_masking) + env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) callback = CustomCallback(1, env) - model = MaskablePPO(MaskableActorCriticPolicy, env, verbose=1, tensorboard_log=create_log_dir(args)) + model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) iterations = args.iterations diff --git a/examples/shields/rl/14_train_eval.py b/examples/shields/rl/14_train_eval.py index 047873e..30cf831 100644 --- a/examples/shields/rl/14_train_eval.py +++ b/examples/shields/rl/14_train_eval.py @@ -9,18 +9,20 @@ from ray.tune import register_env from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.dqn.dqn import DQNConfig # from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.tune.logger import pretty_print +from 
ray.tune.logger import pretty_print, TBXLogger, TBXLoggerCallback, DEFAULT_LOGGERS, UnifiedLogger, CSVLogger from ray.rllib.models import ModelCatalog from TorchActionMaskModel import TorchActionMaskModel from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper from helpers import parse_arguments, create_log_dir, ShieldingConfig -from ShieldHandlers import MiniGridShieldHandler +from ShieldHandlers import MiniGridShieldHandler, create_shield_query + +from callbacks import MyCallbacks import matplotlib.pyplot as plt +from torch.utils.tensorboard import SummaryWriter -from ray.tune.logger import TBXLogger @@ -39,7 +41,7 @@ def shielding_env_creater(config): shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) env = gym.make(name) - env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, mask_actions=shielding) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) env = OneHotShieldingWrapper(env, config.vector_index if hasattr(config, "vector_index") else 0, @@ -67,16 +69,18 @@ def ppo(args): .rollouts(num_rollout_workers=args.workers) .resources(num_gpus=0) .environment( env="mini-grid-shielding", - env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Training}) + env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) .framework("torch") - .evaluation(evaluation_config={ "evaluation_interval": 1, - "evaluation_parallel_to_training": False, + .callbacks(MyCallbacks) + .evaluation(evaluation_config={ + "evaluation_interval": 1, + "evaluation_duration": 10, + "evaluation_num_workers":1, "env": "mini-grid-shielding", - "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Enabled or args.shielding is ShieldingConfig.Evaluation}}) - #.callbacks(MyCallbacks) + "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) .rl_module(_enable_rl_module_api = False) .debugging(logger_config={ - "type": TBXLogger, + "type": UnifiedLogger, "logdir": create_log_dir(args) }) .training(_enable_learner_api=False ,model={ @@ -90,17 +94,35 @@ def ppo(args): iterations = args.iterations + + for i in range(iterations): algo.train() - + if i % 5 == 0: algo.save() + writer = SummaryWriter(log_dir=F"{create_log_dir(args)}-eval") + csv_logger = CSVLogger() for i in range(iterations): eval_result = algo.evaluate() print(pretty_print(eval_result)) + print(eval_result) + # logger.on_result(eval_result) + + + evaluation = eval_result['evaluation'] + epsiode_reward_mean = evaluation['episode_reward_mean'] + episode_len_mean = evaluation['episode_len_mean'] + print(epsiode_reward_mean) + writer.add_scalar("evaluation/episode_reward_mean", epsiode_reward_mean, i) + writer.add_scalar("evaluation/episode_len_mean", episode_len_mean, i) + + + writer.close() + def main(): import argparse diff --git a/examples/shields/rl/15_train_eval_tune.py b/examples/shields/rl/15_train_eval_tune.py new file mode 100644 index 0000000..dc653aa --- /dev/null +++ b/examples/shields/rl/15_train_eval_tune.py @@ -0,0 +1,118 @@ + +import gymnasium as gym + +import minigrid +# import numpy as np + +# import ray +from ray.tune import register_env +from ray import tune, air +from 
ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.dqn.dqn import DQNConfig +# from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.tune.logger import pretty_print, TBXLogger, TBXLoggerCallback, DEFAULT_LOGGERS, UnifiedLogger +from ray.rllib.models import ModelCatalog + + +from TorchActionMaskModel import TorchActionMaskModel +from Wrappers import OneHotShieldingWrapper, MiniGridShieldingWrapper +from helpers import parse_arguments, create_log_dir, ShieldingConfig +from ShieldHandlers import MiniGridShieldHandler, create_shield_query + +from callbacks import MyCallbacks + +import matplotlib.pyplot as plt +from torch.utils.tensorboard import SummaryWriter + + + + +def shielding_env_creater(config): + name = config.get("name", "MiniGrid-LavaCrossingS9N1-v0") + framestack = config.get("framestack", 4) + args = config.get("args", None) + args.grid_path = F"{args.grid_path}_{config.worker_index}.txt" + args.prism_path = F"{args.prism_path}_{config.worker_index}.prism" + + shielding = config.get("shielding", False) + + # if shielding: + # assert(False) + + shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) + + env = gym.make(name) + env = MiniGridShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query ,mask_actions=shielding) + + env = OneHotShieldingWrapper(env, + config.vector_index if hasattr(config, "vector_index") else 0, + framestack=framestack + ) + + + return env + + +def register_minigrid_shielding_env(args): + env_name = "mini-grid-shielding" + register_env(env_name, shielding_env_creater) + + ModelCatalog.register_custom_model( + "shielding_model", + TorchActionMaskModel + ) + + +def ppo(args): + register_minigrid_shielding_env(args) + + config = (PPOConfig() + .rollouts(num_rollout_workers=args.workers) + .resources(num_gpus=0) + .environment( env="mini-grid-shielding", + env_config={"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Training}) + .framework("torch") + .callbacks(MyCallbacks) + .evaluation(evaluation_config={ + "evaluation_interval": 1, + "evaluation_duration": 10, + "evaluation_num_workers":1, + "env": "mini-grid-shielding", + "env_config": {"name": args.env, "args": args, "shielding": args.shielding is ShieldingConfig.Full or args.shielding is ShieldingConfig.Evaluation}}) + .rl_module(_enable_rl_module_api = False) + .debugging(logger_config={ + "type": UnifiedLogger, + "logdir": create_log_dir(args) + }) + .training(_enable_learner_api=False ,model={ + "custom_model": "shielding_model" + })) + + tuner = tune.Tuner("PPO", + run_config=air.RunConfig( + stop = {"episode_reward_mean": 50}, + checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True), + storage_path=F"{create_log_dir(args)}-tuner" + ), + param_space=config,) + + tuner.fit() + + iterations = args.iterations + print(config.to_dict()) + tune.run("PPO", config=config) + + # print(epsiode_reward_mean) + # writer.add_scalar("evaluation/episode_reward", epsiode_reward_mean, i) + + +def main(): + import argparse + args = parse_arguments(argparse) + + ppo(args) + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/shields/rl/ShieldHandlers.py b/examples/shields/rl/ShieldHandlers.py index 0ec7936..eda1768 100644 --- a/examples/shields/rl/ShieldHandlers.py +++ b/examples/shields/rl/ShieldHandlers.py @@ -15,7 +15,7 @@ import os class ShieldHandler(ABC): def 
__init__(self) -> None: pass - def create_shield(self, **kwargs): + def create_shield(self, **kwargs) -> dict: pass class MiniGridShieldHandler(ShieldHandler): @@ -32,7 +32,9 @@ class MiniGridShieldHandler(ShieldHandler): def __create_prism(self): - os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}") + result = os.system(F"{self.grid_to_prism_path} -v 'agent' -i {self.grid_file} -o {self.prism_path}") + + assert result == 0, "Prism file could not be generated" f = open(self.prism_path, "a") f.write("label \"AgentIsInLava\" = AgentIsInLava;") @@ -78,4 +80,17 @@ class MiniGridShieldHandler(ShieldHandler): self.__create_prism() return self.__create_shield_dict() - \ No newline at end of file + +def create_shield_query(env): + coordinates = env.env.agent_pos + view_direction = env.env.agent_dir + + key_text = "" + + # only support one key for now + + #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") + cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" + + return cur_pos_str + \ No newline at end of file diff --git a/examples/shields/rl/Wrappers.py b/examples/shields/rl/Wrappers.py index 0e5cc05..1af6fec 100644 --- a/examples/shields/rl/Wrappers.py +++ b/examples/shields/rl/Wrappers.py @@ -82,7 +82,12 @@ class OneHotShieldingWrapper(gym.core.ObservationWrapper): class MiniGridShieldingWrapper(gym.core.Wrapper): - def __init__(self, env, shield_creator : ShieldHandler, create_shield_at_reset=True, mask_actions=True): + def __init__(self, + env, + shield_creator : ShieldHandler, + shield_query_creator, + create_shield_at_reset=True, + mask_actions=True): super(MiniGridShieldingWrapper, self).__init__(env) self.max_available_actions = env.action_space.n self.observation_space = Dict( @@ -95,32 +100,18 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): self.create_shield_at_reset = create_shield_at_reset self.shield = shield_creator.create_shield(env=self.env) self.mask_actions = mask_actions + self.shield_query_creator = shield_query_creator def create_action_mask(self): if not self.mask_actions: return np.array([1.0] * self.max_available_actions, dtype=np.int8) - coordinates = self.env.agent_pos - view_direction = self.env.agent_dir - - key_text = "" - - # only support one key for now - if self.keys: - key_text = F"!Agent_has_{self.keys[0]}_key\t& " - - - if self.env.carrying and self.env.carrying.type == "key": - key_text = F"Agent_has_{self.env.carrying.color}_key\t& " - - cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" - - allowed_actions = [] - + cur_pos_str = self.shield_query_creator(self.env) # Create the mask # If shield restricts action mask only valid with 1.0 # else set all actions as valid + allowed_actions = [] mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) if cur_pos_str in self.shield and self.shield[cur_pos_str]: @@ -144,7 +135,7 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): if front_tile and front_tile.type == "door": mask[Actions.toggle] = 1.0 - + return mask def reset(self, *, seed=None, options=None): @@ -175,38 +166,32 @@ class MiniGridShieldingWrapper(gym.core.Wrapper): class MiniGridSbShieldingWrapper(gym.core.Wrapper): - def __init__(self, env, shield_creator : ShieldHandler, no_masking=False): + def __init__(self, + env, + shield_creator : ShieldHandler, + shield_query_creator, + create_shield_at_reset = True, + 
mask_actions=True, + ): super(MiniGridSbShieldingWrapper, self).__init__(env) self.max_available_actions = env.action_space.n self.observation_space = env.observation_space.spaces["image"] self.shield_creator = shield_creator - self.no_masking = no_masking + self.mask_actions = mask_actions + self.shield_query_creator = shield_query_creator def create_action_mask(self): - if self.no_masking: + if not self.mask_actions: return np.array([1.0] * self.max_available_actions, dtype=np.int8) - coordinates = self.env.agent_pos - view_direction = self.env.agent_dir - - key_text = "" - - # only support one key for now - if self.keys: - key_text = F"!Agent_has_{self.keys[0]}_key\t& " - - - if self.env.carrying and self.env.carrying.type == "key": - key_text = F"Agent_has_{self.env.carrying.color}_key\t& " - - #print(F"Agent pos is {self.env.agent_pos} and direction {self.env.agent_dir} ") - cur_pos_str = f"[{key_text}!AgentDone\t& xAgent={coordinates[0]}\t& yAgent={coordinates[1]}\t& viewAgent={view_direction}]" + cur_pos_str = self.shield_query_creator(self.env) + allowed_actions = [] # Create the mask - # If shield restricts action mask only valid with 1.0 - # else set all actions as valid + # If shield restricts actions, mask only valid actions with 1.0 + # else set all actions valid mask = np.array([0.0] * self.max_available_actions, dtype=np.int8) if cur_pos_str in self.shield and self.shield[cur_pos_str]: @@ -215,24 +200,20 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): index = get_action_index_mapping(allowed_action[1]) if index is None: assert(False) + mask[index] = 1.0 else: - # print(F"Not in shield {cur_pos_str}") for index, x in enumerate(mask): mask[index] = 1.0 front_tile = self.env.grid.get(self.env.front_pos[0], self.env.front_pos[1]) - # if front_tile is not None and front_tile.type == "key": - # mask[Actions.pickup] = 1.0 - - # if self.env.carrying: - # mask[Actions.drop] = 1.0 if front_tile and front_tile.type == "door": mask[Actions.toggle] = 1.0 - return mask + return mask + def reset(self, *, seed=None, options=None): obs, infos = self.env.reset(seed=seed, options=options) @@ -245,7 +226,7 @@ class MiniGridSbShieldingWrapper(gym.core.Wrapper): return obs["image"], infos def step(self, action): - orig_obs, rew, done, truncated, info = self.env.step(action) + orig_obs, rew, done, truncated, info = self.env.step(action) obs = orig_obs["image"] return obs, rew, done, truncated, info diff --git a/examples/shields/rl/callbacks.py b/examples/shields/rl/callbacks.py new file mode 100644 index 0000000..199bb65 --- /dev/null +++ b/examples/shields/rl/callbacks.py @@ -0,0 +1,61 @@ + +from typing import Dict + +from ray.rllib.policy import Policy +from ray.rllib.utils.typing import PolicyID + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.env.base_env import BaseEnv +from ray.rllib.evaluation import RolloutWorker +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.episode_v2 import EpisodeV2 + +from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks + +class MyCallbacks(DefaultCallbacks): + def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: + # print(F"Epsiode started Environment: {base_env.get_sub_environments()}") + env = base_env.get_sub_environments()[0] + episode.user_data["count"] = 0 + episode.user_data["ran_into_lava"] = [] + episode.user_data["goals_reached"] = [] + 
episode.hist_data["ran_into_lava"] = [] + episode.hist_data["goals_reached"] = [] + # print("On episode start print") + # print(env.printGrid()) + # print(worker) + # print(env.action_space.n) + # print(env.actions) + # print(env.mission) + # print(env.observation_space) + # img = env.get_frame() + # plt.imshow(img) + # plt.show() + + + def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy] | None = None, episode: Episode | EpisodeV2, env_index: int | None = None, **kwargs) -> None: + episode.user_data["count"] = episode.user_data["count"] + 1 + env = base_env.get_sub_environments()[0] + # print(env.printGrid()) + + def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], episode: Episode | EpisodeV2 | Exception, env_index: int | None = None, **kwargs) -> None: + # print(F"Epsiode end Environment: {base_env.get_sub_environments()}") + env = base_env.get_sub_environments()[0] + agent_tile = env.grid.get(env.agent_pos[0], env.agent_pos[1]) + + episode.user_data["goals_reached"].append(agent_tile is not None and agent_tile.type == "goal") + episode.user_data["ran_into_lava"].append(agent_tile is not None and agent_tile.type == "lava") + episode.custom_metrics["reached_goal"] = agent_tile is not None and agent_tile.type == "goal" + episode.custom_metrics["ran_into_lava"] = agent_tile is not None and agent_tile.type == "lava" + #print("On episode end print") + #print(env.printGrid()) + episode.hist_data["goals_reached"] = episode.user_data["goals_reached"] + episode.hist_data["ran_into_lava"] = episode.user_data["ran_into_lava"] + + + def on_evaluate_start(self, *, algorithm: Algorithm, **kwargs) -> None: + print("Evaluate Start") + + def on_evaluate_end(self, *, algorithm: Algorithm, evaluation_metrics: dict, **kwargs) -> None: + print("Evaluate End") + diff --git a/examples/shields/rl/helpers.py b/examples/shields/rl/helpers.py index 5863d99..f906abf 100644 --- a/examples/shields/rl/helpers.py +++ b/examples/shields/rl/helpers.py @@ -20,7 +20,7 @@ class ShieldingConfig(Enum): Training = 'training' Evaluation = 'evaluation' Disabled = 'none' - Enabled = 'full' + Full = 'full' def __str__(self) -> str: return self.value @@ -39,7 +39,7 @@ def extract_keys(env): return keys def create_log_dir(args): - return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}" + return F"{args.log_dir}{datetime.now()}-{args.algorithm}-shielding:{args.shielding}-env:{args.env}-iterations:{args.iterations}" def get_action_index_mapping(actions): @@ -93,7 +93,7 @@ def parse_arguments(argparse): parser.add_argument("--iterations", type=int, default=30 ) parser.add_argument("--formula", default="Pmax=? [G !\"AgentIsInLavaAndNotDone\"]") # formula_str = "Pmax=? [G ! \"AgentIsInGoalAndNotDone\"]" parser.add_argument("--workers", type=int, default=1) - parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Enabled) + parser.add_argument("--shielding", type=ShieldingConfig, choices=list(ShieldingConfig), default=ShieldingConfig.Full) args = parser.parse_args()
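---
Reviewer note (illustrative, not part of the committed patch): the two sketches below summarise the main patterns this change introduces, under stated assumptions.

First, a minimal sketch of how the extracted create_shield_query() feeds the wrappers' action masks. It assumes the local ShieldHandlers and helpers modules from this directory are importable and that `shield` is the dict produced by MiniGridShieldHandler.create_shield(); the function name `shield_action_mask` is illustrative only.

import numpy as np

from ShieldHandlers import create_shield_query   # local module touched by this patch
from helpers import get_action_index_mapping      # local helper, assumed available

def shield_action_mask(env, shield, n_actions, mask_actions=True):
    # Build a 0/1 mask over the env's actions from a shield lookup.
    if not mask_actions:
        return np.ones(n_actions, dtype=np.int8)
    # e.g. "[!AgentDone\t& xAgent=1\t& yAgent=2\t& viewAgent=0]"
    query = create_shield_query(env)
    mask = np.zeros(n_actions, dtype=np.int8)
    if query in shield and shield[query]:
        # The shield lists the allowed actions; map each label to its index.
        for allowed in shield[query]:
            idx = get_action_index_mapping(allowed[1])
            if idx is not None:
                mask[idx] = 1
    else:
        # No shield entry for this state: leave every action allowed.
        mask[:] = 1
    return mask

Second, a condensed sketch of the evaluation logging added to 14_train_eval.py. It assumes an already-built RLlib Algorithm `algo` (e.g. PPOConfig(...).build()) with evaluation workers configured, plus a log directory string `log_dir`; both names are illustrative.

from ray.tune.logger import pretty_print
from torch.utils.tensorboard import SummaryWriter

def log_evaluation(algo, log_dir, iterations):
    # Write evaluation metrics to a separate TensorBoard run so they do not
    # mix with the training curves written by RLlib itself.
    writer = SummaryWriter(log_dir=f"{log_dir}-eval")
    for i in range(iterations):
        eval_result = algo.evaluate()
        print(pretty_print(eval_result))
        evaluation = eval_result["evaluation"]
        writer.add_scalar("evaluation/episode_reward_mean",
                          evaluation["episode_reward_mean"], i)
        writer.add_scalar("evaluation/episode_len_mean",
                          evaluation["episode_len_mean"], i)
    writer.close()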