|
|
@ -15,6 +15,7 @@ from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecor |
|
|
|
from stable_baselines3.common.callbacks import EvalCallback |
|
|
|
|
|
|
|
import os, sys |
|
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") |
|
|
|
def mask_fn(env: gym.Env): |
|
|
@ -27,32 +28,48 @@ def nomask_fn(env: gym.Env): |
|
|
|
def main(): |
|
|
|
args = parse_sb3_arguments() |
|
|
|
|
|
|
|
|
|
|
|
formula = args.formula |
|
|
|
shield_value = args.shield_value |
|
|
|
shield_comparison = args.shield_comparison |
|
|
|
log_dir = create_log_dir(args) |
|
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir), HumanOutputFormat(sys.stdout)]) |
|
|
|
new_logger = Logger(log_dir, output_formats=[CSVOutputFormat(os.path.join(log_dir, f"progress_{expname(args)}.csv")), TensorBoardOutputFormat(log_dir)]) |
|
|
|
|
|
|
|
if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training or args.shielding == ShieldingConfig.Evaluation: |
|
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) |
|
|
|
env = gym.make(args.env, render_mode="rgb_array") |
|
|
|
env = RGBImgObsWrapper(env) |
|
|
|
env = ImgObsWrapper(env) |
|
|
|
env = MiniWrapper(env) |
|
|
|
if args.shielding == ShieldingConfig.Full or args.shielding == ShieldingConfig.Training: |
|
|
|
shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison, nocleanup=args.nocleanup) |
|
|
|
eval_env = deepcopy(env) |
|
|
|
eval_env.disable_random_start() |
|
|
|
if args.shielding == ShieldingConfig.Full: |
|
|
|
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) |
|
|
|
env = ActionMasker(env, mask_fn) |
|
|
|
else: |
|
|
|
eval_env = MiniGridSbShieldingWrapper(eval_env, shield_handler=shield_handler, create_shield_at_reset=False) |
|
|
|
eval_env = ActionMasker(eval_env, mask_fn) |
|
|
|
elif args.shielding == ShieldingConfig.Training: |
|
|
|
env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False) |
|
|
|
env = ActionMasker(env, mask_fn) |
|
|
|
eval_env = ActionMasker(eval_env, nomask_fn) |
|
|
|
elif args.shielding == ShieldingConfig.Evaluation: |
|
|
|
env = ActionMasker(env, nomask_fn) |
|
|
|
eval_env = MiniGridSbShieldingWrapper(eval_env, shield_handler=shield_handler, create_shield_at_reset=False) |
|
|
|
eval_env = ActionMasker(eval_env, mask_fn) |
|
|
|
elif args.shielding == ShieldingConfig.Disabled: |
|
|
|
env = ActionMasker(env, nomask_fn) |
|
|
|
eval_env = ActionMasker(eval_env, nomask_fn) |
|
|
|
else: |
|
|
|
assert(False) # TODO Do something proper |
|
|
|
model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=log_dir, device="auto") |
|
|
|
model.set_logger(new_logger) |
|
|
|
|
|
|
|
|
|
|
|
evalCallback = EvalCallback(env, best_model_save_path=log_dir, |
|
|
|
evalCallback = EvalCallback(eval_env, best_model_save_path=log_dir, |
|
|
|
log_path=log_dir, eval_freq=max(500, int(args.steps/30)), |
|
|
|
deterministic=True, render=False) |
|
|
|
deterministic=True, render=False, n_eval_episodes=5) |
|
|
|
steps = args.steps |
|
|
|
|
|
|
|
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()]) |
|
|
|
model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback(), evalCallback]) |
|
|
|
|
|
|
|
#vec_env = model.get_env() |
|
|
|
#obs = vec_env.reset() |
|
|
|