From b9ed2ac23471a9c1ddd4db725d3b74f564860478 Mon Sep 17 00:00:00 2001 From: sp Date: Thu, 18 Jan 2024 22:02:42 +0100 Subject: [PATCH] changed some callbacks --- examples/shields/rl/13_minigridsb.py | 6 +++--- examples/shields/rl/sb3utils.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index 175dc8e..07bfaa1 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -33,8 +33,8 @@ def main(): shield_value = args.shield_value shield_comparison = args.shield_comparison log_dir = create_log_dir(args) - new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) - #new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) + #new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) + new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) if shield_needed(args.shielding): @@ -89,7 +89,7 @@ def main(): imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) - model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback]) + model.learn(steps,callback=[imageAndVideoCallback, InfoCallback()]) #vec_env = model.get_env() #obs = vec_env.reset() diff --git a/examples/shields/rl/sb3utils.py b/examples/shields/rl/sb3utils.py index 4aff45c..c87f719 100644 --- a/examples/shields/rl/sb3utils.py +++ b/examples/shields/rl/sb3utils.py @@ -61,8 +61,8 @@ class ImageRecorderCallback(BaseCallback): self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) def _on_step(self) -> bool: - if self.n_calls % self._render_freq == 0: - self.record_video() + #if self.n_calls % self._render_freq == 0: + # self.record_video() return True def _on_training_end(self) -> None: