Browse Source

enabled evaluation

refactoring
sp 5 months ago
parent
commit
83076beb95
  1. 7
      examples/shields/rl/13_minigridsb.py

7
examples/shields/rl/13_minigridsb.py

@ -69,8 +69,9 @@ def main():
# Evaluation # Evaluation
eval_freq=max(500, int(args.steps/30))
n_eval_episodes=5
#eval_freq=max(500, int(args.steps/30))
eval_freq=10000
n_eval_episodes=25
render_freq = eval_freq render_freq = eval_freq
if shielded_evaluation(args.shielding): if shielded_evaluation(args.shielding):
from sb3_contrib.common.maskable.evaluation import evaluate_policy from sb3_contrib.common.maskable.evaluation import evaluate_policy
@ -86,7 +87,7 @@ def main():
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0) imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback()])
model.learn(steps,callback=[evalCallback, imageAndVideoCallback, InfoCallback()])
model.save(f"{log_dir}/{expname(args)}") model.save(f"{log_dir}/{expname(args)}")

Loading…
Cancel
Save