You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
68 lines
2.3 KiB
68 lines
2.3 KiB
from sb3_contrib import MaskablePPO
|
|
from sb3_contrib.common.wrappers import ActionMasker
|
|
from stable_baselines3.common.logger import configure
|
|
|
|
import gymnasium as gym
|
|
|
|
from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper
|
|
|
|
from utils import ShieldingConfig, MiniWrapper, create_shield_overlay_image
|
|
from minigrid_shield_handler import MiniGridShieldHandler
|
|
from sb3utils import MiniGridSbShieldingWrapper, InfoCallback, parse_sb3_arguments
|
|
|
|
import os
|
|
import datetime
|
|
|
|
from PIL import Image
|
|
|
|
# Value of the M2P_BINARY environment variable (None when it is unset).
# main() hands it to MiniGridShieldHandler as its first argument.
GRID_TO_PRISM_BINARY = os.environ.get("M2P_BINARY")
|
|
def mask_fn(env: gym.Env):
    """Return the action mask reported by the environment itself.

    Used as the mask callback for ``ActionMasker`` when shielding is fully
    enabled; it simply delegates to ``env.create_action_mask()``.
    """
    action_mask = env.create_action_mask()
    return action_mask
|
|
|
|
def nomask_fn(env: gym.Env):
    """Return a constant mask of four ``1.0`` entries, ignoring ``env``.

    Used as the mask callback for ``ActionMasker`` when shielding is
    disabled (presumably so that no action is ever masked out).
    """
    action_count = 4
    return [1.0 for _ in range(action_count)]
|
|
|
|
def main(env_name, seed=None):
    """Train a MaskablePPO agent on a MiniGrid environment, optionally shielded.

    Reads the module-level ``args`` namespace (bound in the ``__main__``
    guard) for ``shielding``, ``grid_file``, ``prism_output_file`` and
    ``prism_file``.

    Side effects: saves a rendering of the environment to
    ``/opt/workspace/env.png`` (and ``/opt/workspace/env_and_shield.png``
    when shielding is ``ShieldingConfig.Full``), and logs training output to
    stdout and CSV under ``./training_results/``.

    Args:
        env_name: Environment id handed to ``gym.make``.
        seed: Optional seed forwarded to ``gym.make`` as a keyword argument.
            ``None`` means no seed is passed; ``0`` is a valid seed.

    Raises:
        AssertionError: If ``args.shielding`` is neither
            ``ShieldingConfig.Full`` nor ``ShieldingConfig.Disabled``.
    """
    # Property handed to the shield handler, together with the threshold and
    # the comparison mode it is evaluated with.
    formula = ["Pmin=? [F<=2 (", "AgentIsOnLava", ")]"]
    shield_value = 0.01
    shield_comparison = "absolute"

    timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S")
    log_path = f"./training_results/{args.shielding}_{timestamp}_{env_name}"
    new_logger = configure(log_path, ["stdout", "csv"])

    # Compare against None explicitly: a plain truthiness test would silently
    # drop a seed of 0.
    make_kwargs = {"render_mode": "rgb_array"}
    if seed is not None:
        make_kwargs["seed"] = seed
    env = gym.make(env_name, **make_kwargs)
    env.reset()

    env = RGBImgObsWrapper(env)
    env = ImgObsWrapper(env)
    env = MiniWrapper(env)

    # Snapshot of the freshly reset environment.
    img = Image.fromarray(env.render())
    img.save("/opt/workspace/env.png")

    shield_handler = MiniGridShieldHandler(
        GRID_TO_PRISM_BINARY,
        args.grid_file,
        args.prism_output_file,
        formula,
        env,
        shield_value=shield_value,
        shield_comparison=shield_comparison,
        nocleanup=False,
        prism_file=args.prism_file,
        ignore_view=True,
    )

    if args.shielding == ShieldingConfig.Full:
        env = MiniGridSbShieldingWrapper(
            env,
            shield_handler=shield_handler,
            create_shield_at_reset=False,
        )
        env = ActionMasker(env, mask_fn)
        img = create_shield_overlay_image(
            env,
            shield_handler.action_dictionary,
            shield_handler.dangerous_states,
        )
        img.save("/opt/workspace/env_and_shield.png")
    elif args.shielding == ShieldingConfig.Disabled:
        env = ActionMasker(env, nomask_fn)
    else:
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let an unsupported mode fall
        # through to training on an environment without an ActionMasker.
        raise AssertionError(
            f"Unsupported shielding configuration: {args.shielding!r}"
        )

    model = MaskablePPO("CnnPolicy", env, verbose=1, device="auto")
    model.set_logger(new_logger)
    steps = 10000
    model.learn(steps, callback=[InfoCallback()])
|
|
|
|
|
|
if __name__ == "__main__":
    # `args` is bound at module level on purpose: main() reads it as a global
    # for the shielding mode and the grid/PRISM file options.
    args = parse_sb3_arguments()
    main(env_name=args.env)
|
|
|