|
|
@ -1,6 +1,7 @@ |
|
|
|
from sb3_contrib import MaskablePPO |
|
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
|
from stable_baselines3.common.logger import configure |
|
|
|
|
|
|
|
import gymnasium as gym |
|
|
|
|
|
|
@ -30,6 +31,7 @@ def main(): |
|
|
|
shield_value = args.shield_value |
|
|
|
shield_comparison = args.shield_comparison |
|
|
|
log_dir = create_log_dir(args) |
|
|
|
new_logger = configure(log_dir, ["csv", "tensorboard"]) |
|
|
|
|
|
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
|
env = RGBImgObsWrapper(env) |
|
|
@ -42,6 +44,7 @@ def main(): |
|
|
|
else: |
|
|
|
env = ActionMasker(env, nomask_fn) |
|
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") |
|
|
|
model.set_logger(new_logger) |
|
|
|
|
|
|
|
|
|
|
|
evalCallback = EvalCallback(env, best_model_save_path=log_dir, |
|
|
|