"""Train a MaskablePPO agent on a MiniGrid environment, optionally behind a shield.

A ``MiniGridShieldHandler`` is built from the grid/PRISM files named on the
command line. When shielding is enabled, its output is exposed to the agent as
an action mask via ``ActionMasker``. When shielding is disabled, every action
is allowed.
"""
from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.logger import configure
import gymnasium as gym
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
from utils import ShieldingConfig, MiniWrapper, create_shield_overlay_image
from minigrid_shield_handler import MiniGridShieldHandler
from sb3utils import MiniGridSbShieldingWrapper, InfoCallback, parse_sb3_arguments
import os
import datetime
from PIL import Image

# Path read from the M2P_BINARY environment variable (None when unset); it is
# handed to MiniGridShieldHandler, presumably as the grid-to-PRISM converter.
GRID_TO_PRISM_BINARY = os.getenv("M2P_BINARY")


def mask_fn(env: gym.Env):
    """Return the shield's action mask by delegating to ``env.create_action_mask()``."""
    return env.create_action_mask()


def nomask_fn(env: gym.Env):
    """Return an all-ones mask that allows every action (used when shielding is disabled)."""
    # Size the mask from the action space rather than hard-coding it, so the
    # mask length always matches the number of policy logits MaskablePPO uses.
    return [1.0] * env.action_space.n


def main(env_name, seed=None, args=None):
    """Build the (optionally shielded) environment and train MaskablePPO on it.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        seed: Optional seed forwarded to the initial ``env.reset`` call.
            ``None`` means unseeded; ``0`` is a valid seed.
        args: Parsed command-line arguments (the object returned by
            ``parse_sb3_arguments``). When omitted, ``parse_sb3_arguments()``
            is called, so ``main`` no longer needs a module-level ``args``.

    Raises:
        ValueError: if ``args.shielding`` is neither ``ShieldingConfig.Full``
            nor ``ShieldingConfig.Disabled``.
    """
    if args is None:
        args = parse_sb3_arguments()

    # PRISM-style query: minimal probability of reaching a state labelled
    # AgentIsOnLava within 2 steps. Passed to the handler as fragments.
    formula = ["Pmin=? [F<=2 (", "AgentIsOnLava", ")]"]
    # NOTE(review): threshold semantics depend on MiniGridShieldHandler.
    shield_value = 0.01
    shield_comparison = "absolute"

    log_path = f"./training_results/{args.shielding}_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}_{env_name}"
    new_logger = configure(log_path, ["stdout", "csv"])

    env = gym.make(env_name, render_mode="rgb_array")
    # Gymnasium seeds environments through reset(seed=...); extra kwargs to
    # gym.make go to the env constructor, which need not accept `seed`.
    # reset(seed=None) is equivalent to the previous unseeded reset().
    env.reset(seed=seed)
    env = RGBImgObsWrapper(env)
    env = ImgObsWrapper(env)
    env = MiniWrapper(env)

    # NOTE(review): hard-coded output path; /opt/workspace must exist.
    img = Image.fromarray(env.render())
    img.save("/opt/workspace/env.png")

    shield_handler = MiniGridShieldHandler(
        GRID_TO_PRISM_BINARY,
        args.grid_file,
        args.prism_output_file,
        formula,
        env,
        shield_value=shield_value,
        shield_comparison=shield_comparison,
        nocleanup=False,
        prism_file=args.prism_file,
        ignore_view=True,
    )

    if args.shielding == ShieldingConfig.Full:
        env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False)
        env = ActionMasker(env, mask_fn)
        img = create_shield_overlay_image(env, shield_handler.action_dictionary, shield_handler.dangerous_states)
        img.save("/opt/workspace/env_and_shield.png")
    elif args.shielding == ShieldingConfig.Disabled:
        env = ActionMasker(env, nomask_fn)
    else:
        # An explicit error instead of `assert False`, which `python -O` strips
        # (training would then silently run on an unmasked env).
        raise ValueError(f"Unsupported shielding configuration: {args.shielding}")

    model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
    model.set_logger(new_logger)
    steps = 10000

    model.learn(steps, callback=[InfoCallback()])


if __name__ == '__main__':
    args = parse_sb3_arguments()
    main(args.env, args=args)