|
@ -1,7 +1,7 @@ |
|
|
from sb3_contrib import MaskablePPO |
|
|
from sb3_contrib import MaskablePPO |
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
from stable_baselines3.common.logger import configure |
|
|
|
|
|
|
|
|
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat |
|
|
|
|
|
|
|
|
import gymnasium as gym |
|
|
import gymnasium as gym |
|
|
|
|
|
|
|
@ -10,7 +10,7 @@ from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper |
|
|
|
|
|
|
|
|
import time |
|
|
import time |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname |
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |
|
|
from stable_baselines3.common.callbacks import EvalCallback |
|
|
from stable_baselines3.common.callbacks import EvalCallback |
|
|
|
|
|
|
|
@ -31,7 +31,7 @@ def main(): |
|
|
shield_value = args.shield_value |
|
|
shield_value = args.shield_value |
|
|
shield_comparison = args.shield_comparison |
|
|
shield_comparison = args.shield_comparison |
|
|
log_dir = create_log_dir(args) |
|
|
log_dir = create_log_dir(args) |
|
|
new_logger = configure(log_dir, ["csv", "tensorboard"]) |
|
|
|
|
|
|
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) |
|
|
|
|
|
|
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
env = RGBImgObsWrapper(env) |
|
|
env = RGBImgObsWrapper(env) |
|
|