diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index c5364b9..0fab627 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -1,19 +1,20 @@ from sb3_contrib import MaskablePPO from sb3_contrib.common.maskable.evaluation import evaluate_policy -from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy from sb3_contrib.common.wrappers import ActionMasker import gymnasium as gym from minigrid.core.actions import Actions +from minigrid.wrappers import RGBImgObsWrapper, ImgObsWrapper import time -from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig -from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments +from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig, MiniWrapper +from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments, ImageRecorderCallback, InfoCallback -GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" +import os +GRID_TO_PRISM_BINARY=os.getenv("M2P_BINARY") def mask_fn(env: gym.Env): return env.create_action_mask() @@ -27,14 +28,16 @@ def main(): shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) env = gym.make(args.env, render_mode="rgb_array") + env = RGBImgObsWrapper(env) # Get pixel observations + env = ImgObsWrapper(env) # Get rid of the 'mission' field + env = MiniWrapper(env) env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) - model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) + model = MaskablePPO("CnnPolicy", env, verbose=1, tensorboard_log=create_log_dir(args)) steps = args.steps - - model.learn(steps) + model.learn(steps,callback=[ImageRecorderCallback(), InfoCallback()], log_interval=1) print("Learning done, hit enter")