Browse Source

changed some callbacks

refactoring
sp 10 months ago
parent
commit
b9ed2ac234
  1. 6
      examples/shields/rl/13_minigridsb.py
  2. 4
      examples/shields/rl/sb3utils.py

6
examples/shields/rl/13_minigridsb.py

@ -33,8 +33,8 @@ def main():
shield_value = args.shield_value shield_value = args.shield_value
shield_comparison = args.shield_comparison shield_comparison = args.shield_comparison
log_dir = create_log_dir(args) log_dir = create_log_dir(args)
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
#new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)])
#new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)])
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)])
if shield_needed(args.shielding): if shield_needed(args.shielding):
@ -89,7 +89,7 @@ def main():
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback])
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback()])
#vec_env = model.get_env() #vec_env = model.get_env()
#obs = vec_env.reset() #obs = vec_env.reset()

4
examples/shields/rl/sb3utils.py

@ -61,8 +61,8 @@ class ImageRecorderCallback(BaseCallback):
self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv")) self.logger.record("trajectory/image", Image(image, "HWC"), exclude=("stdout", "log", "json", "csv"))
def _on_step(self) -> bool: def _on_step(self) -> bool:
if self.n_calls % self._render_freq == 0:
self.record_video()
#if self.n_calls % self._render_freq == 0:
# self.record_video()
return True return True
def _on_training_end(self) -> None: def _on_training_end(self) -> None:

Loading…
Cancel
Save