|
@ -1,5 +1,5 @@ |
|
|
from sb3_contrib import MaskablePPO |
|
|
from sb3_contrib import MaskablePPO |
|
|
from sb3_contrib.common.maskable.evaluation import evaluate_policy |
|
|
|
|
|
|
|
|
from sb3_contrib.common.maskable.callbacks import MaskableEvalCallback |
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
from sb3_contrib.common.wrappers import ActionMasker |
|
|
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat |
|
|
from stable_baselines3.common.logger import Logger, CSVOutputFormat, TensorBoardOutputFormat, HumanOutputFormat |
|
|
|
|
|
|
|
@ -10,7 +10,7 @@ from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper |
|
|
|
|
|
|
|
|
import time |
|
|
import time |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname |
|
|
|
|
|
|
|
|
from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper, expname, shield_needed, shielded_evaluation |
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |
|
|
from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback |
|
|
from stable_baselines3.common.callbacks import EvalCallback |
|
|
from stable_baselines3.common.callbacks import EvalCallback |
|
|
|
|
|
|
|
@ -35,8 +35,11 @@ def main(): |
|
|
log_dir = create_log_dir(args) |
|
|
log_dir = create_log_dir(args) |
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) |
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) |
|
|
|
|
|
|
|
|
if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training or args.shielding == ShieldingConfig.Evaluation: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if shield_needed(args.shielding): |
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) |
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
env = RGBImgObsWrapper(env) |
|
|
env = RGBImgObsWrapper(env) |
|
|
env = ImgObsWrapper(env) |
|
|
env = ImgObsWrapper(env) |
|
@ -63,11 +66,19 @@ def main(): |
|
|
assert(False) # TODO Do something proper |
|
|
assert(False) # TODO Do something proper |
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") |
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") |
|
|
model.set_logger(new_logger) |
|
|
model.set_logger(new_logger) |
|
|
|
|
|
steps = args.steps |
|
|
|
|
|
|
|
|
|
|
|
eval_freq=max(500, int(args.steps/30)) |
|
|
|
|
|
n_eval_episodes=5 |
|
|
|
|
|
if shielded_evaluation(args.shielding): |
|
|
|
|
|
evalCallback = MaskableEvalCallback(eval_env, best_model_save_path=log_dir, |
|
|
|
|
|
log_path=log_dir, eval_freq=eval_freq, |
|
|
|
|
|
deterministic=True, render=False, n_eval_episodes=n_eval_episodes) |
|
|
|
|
|
else: |
|
|
evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, |
|
|
evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, |
|
|
log_path=log_dir, eval_freq=max(500, int(args.steps/30)), |
|
|
|
|
|
deterministic=True, render=False, n_eval_episodes=5) |
|
|
|
|
|
steps = args.steps |
|
|
|
|
|
|
|
|
log_path=log_dir, eval_freq=eval_freq, |
|
|
|
|
|
deterministic=True, render=False, n_eval_episodes=n_eval_episodes) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback]) |
|
|
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback]) |
|
|
|
|
|
|
|
|