You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

68 lines
2.3 KiB

from sb3_contrib import MaskablePPO
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.logger import configure
import gymnasium as gym
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
from utils import ShieldingConfig, MiniWrapper, create_shield_overlay_image
from minigrid_shield_handler import MiniGridShieldHandler
from sb3utils import MiniGridSbShieldingWrapper, InfoCallback, parse_sb3_arguments
import os
import datetime
from PIL import Image
# Value of the M2P_BINARY environment variable (None when it is unset).
# It is handed to MiniGridShieldHandler as its first argument; presumably the
# path of the grid-to-PRISM converter executable -- confirm against the handler.
GRID_TO_PRISM_BINARY = os.environ.get("M2P_BINARY")
def mask_fn(env: gym.Env):
    """Return the action mask computed by the environment itself.

    Used as the mask callback for ``ActionMasker`` when shielding is enabled.

    Args:
        env: Environment (or wrapper) exposing ``create_action_mask()``.

    Returns:
        Whatever ``env.create_action_mask()`` returns, unmodified.
    """
    action_mask = env.create_action_mask()
    return action_mask
def nomask_fn(env: gym.Env):
    """Return a constant mask of four ``1.0`` entries, ignoring ``env``.

    Used as the mask callback for ``ActionMasker`` when shielding is disabled.

    Args:
        env: Unused; present only to match the mask-callback signature.

    Returns:
        A new list ``[1.0, 1.0, 1.0, 1.0]`` on every call.
    """
    # NOTE(review): the length 4 is hard-coded; presumably it matches the size
    # of the wrapped environment's action space -- confirm.
    return [1.0 for _ in range(4)]
def main(env_name, seed=None):
    """Train a MaskablePPO agent on a MiniGrid environment, with or without a shield.

    The function reads the module-level ``args`` object (created in the
    ``__main__`` guard) for ``shielding``, ``grid_file``, ``prism_output_file``
    and ``prism_file``; it is therefore only usable after ``args`` has been
    bound at module scope.

    Side effects:
        * configures a logger writing to stdout and CSV under
          ``./training_results/<shielding>_<timestamp>_<env_name>``;
        * saves a render of the environment to ``/opt/workspace/env.png``;
        * when shielding is ``ShieldingConfig.Full``, additionally saves
          ``/opt/workspace/env_and_shield.png``;
        * trains the model for 10000 steps.

    Args:
        env_name: Gymnasium environment id passed to ``gym.make``.
        seed: Optional seed forwarded to ``gym.make`` as the ``seed`` keyword.
            ``None`` (the default) means no seed is forwarded; ``0`` is a
            valid seed and is forwarded.

    Raises:
        AssertionError: If ``args.shielding`` is neither
            ``ShieldingConfig.Full`` nor ``ShieldingConfig.Disabled``.
    """
    # Property fragments handed to the shield handler, together with the
    # threshold and comparison mode used when building the shield.
    formula = ["Pmin=? [F<=2 (", "AgentIsOnLava", ")]"]
    shield_value = 0.01
    shield_comparison = "absolute"

    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    log_path = f"./training_results/{args.shielding}_{timestamp}_{env_name}"
    new_logger = configure(log_path, ["stdout", "csv"])

    # Compare against None rather than relying on truthiness: a seed of 0 is
    # legitimate and must not be silently dropped.
    # NOTE(review): the seed is passed to the environment constructor, as in
    # the original code; Gymnasium's documented seeding route is
    # ``env.reset(seed=...)`` -- confirm the target environments accept a
    # ``seed`` constructor keyword.
    if seed is not None:
        env = gym.make(env_name, render_mode="rgb_array", seed=seed)
    else:
        env = gym.make(env_name, render_mode="rgb_array")
    env.reset()

    env = RGBImgObsWrapper(env)
    env = ImgObsWrapper(env)
    env = MiniWrapper(env)

    # Snapshot of the freshly reset environment.
    img = Image.fromarray(env.render())
    img.save("/opt/workspace/env.png")

    shield_handler = MiniGridShieldHandler(
        GRID_TO_PRISM_BINARY,
        args.grid_file,
        args.prism_output_file,
        formula,
        env,
        shield_value=shield_value,
        shield_comparison=shield_comparison,
        nocleanup=False,
        prism_file=args.prism_file,
        ignore_view=True,
    )

    if args.shielding == ShieldingConfig.Full:
        env = MiniGridSbShieldingWrapper(
            env, shield_handler=shield_handler, create_shield_at_reset=False
        )
        env = ActionMasker(env, mask_fn)
        img = create_shield_overlay_image(
            env, shield_handler.action_dictionary, shield_handler.dangerous_states
        )
        img.save("/opt/workspace/env_and_shield.png")
    elif args.shielding == ShieldingConfig.Disabled:
        env = ActionMasker(env, nomask_fn)
    else:
        # An explicit raise (instead of ``assert False``) survives ``python -O``;
        # otherwise training would proceed on an env without an ActionMasker.
        raise AssertionError(
            f"Unsupported shielding configuration: {args.shielding!r}"
        )

    model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
    model.set_logger(new_logger)

    steps = 10000
    model.learn(steps, callback=[InfoCallback()])
if __name__ == "__main__":
    # `args` is intentionally bound at module level: main() reads it as a
    # global for the shielding mode and the grid/PRISM file locations.
    args = parse_sb3_arguments()
    main(env_name=args.env)