|
|
@ -1,7 +1,7 @@ |
|
|
|
from sb3_contrib import MaskablePPO |
|
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
|
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat |
|
|
|
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat |
|
|
|
|
|
|
|
import gymnasium as gym |
|
|
|
|
|
|
@ -14,7 +14,7 @@ from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWr |
|
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |
|
|
|
from stable_baselines3.common.callbacks import EvalCallback |
|
|
|
|
|
|
|
import os |
|
|
|
import os, sys |
|
|
|
|
|
|
|
GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") |
|
|
|
def mask_fn(env: gym.Env): |
|
|
@ -31,7 +31,7 @@ def main(): |
|
|
|
shield_value = args.shield_value |
|
|
|
shield_comparison = args.shield_comparison |
|
|
|
log_dir = create_log_dir(args) |
|
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) |
|
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) |
|
|
|
|
|
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
|
env = RGBImgObsWrapper(env) |
|
|
|