Browse Source

reintroduced learning

refactoring
sp 10 months ago
parent
commit
dc8e4f320d
  1. 52
      examples/shields/rl/13_minigridsb.py

52
examples/shields/rl/13_minigridsb.py

@ -65,32 +65,32 @@ def main():
eval_env = ActionMasker(eval_env, nomask_fn)
else:
assert(False) # TODO: replace this assert with explicit error handling for the unhandled case (asserts are stripped under `python -O`)
#model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
#model.set_logger(new_logger)
#steps = args.steps
## Evaluation
#eval_freq=max(500, int(args.steps/30))
#n_eval_episodes=5
#render_freq = eval_freq
#if shielded_evaluation(args.shielding):
# from sb3_contrib.common.maskable.evaluation import evaluate_policy
# evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir,
# log_path=log_dir, eval_freq=eval_freq,
# deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
# imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
#else:
# from stable_baselines3.common.evaluation import evaluate_policy
# evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir,
# log_path=log_dir, eval_freq=eval_freq,
# deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
# imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
#model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback])
#model.save(f"{log_dir}/{expname(args)}")
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto")
model.set_logger(new_logger)
steps = args.steps
# Evaluation
eval_freq=max(500, int(args.steps/30))
n_eval_episodes=5
render_freq = eval_freq
if shielded_evaluation(args.shielding):
from sb3_contrib.common.maskable.evaluation import evaluate_policy
evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=eval_freq,
deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
else:
from stable_baselines3.common.evaluation import evaluate_policy
evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir,
log_path=log_dir, eval_freq=eval_freq,
deterministic=True, render=False, n_eval_episodes=n_eval_episodes)
imageAndVideoCallback = ImageRecorderCallback(eval_env, render_freq, n_eval_episodes=1, evaluation_method=evaluate_policy, log_dir=log_dir, deterministic=True, verbose=0)
model.learn(steps,callback=[imageAndVideoCallback, InfoCallback(), evalCallback])
model.save(f"{log_dir}/{expname(args)}")
if __name__ == '__main__':

Loading…
Cancel
Save