diff --git a/examples/shields/rl/13_minigridsb.py b/examples/shields/rl/13_minigridsb.py index f04c65e..c5364b9 100644 --- a/examples/shields/rl/13_minigridsb.py +++ b/examples/shields/rl/13_minigridsb.py @@ -9,47 +9,50 @@ from minigrid.core.actions import Actions import time -from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig -from sb3utils import MiniGridSbShieldingWrapper +from utils import MiniGridShieldHandler, create_log_dir, ShieldingConfig +from sb3utils import MiniGridSbShieldingWrapper, parse_sb3_arguments + +GRID_TO_PRISM_BINARY="/home/spranger/research/tempestpy/Minigrid2PRISM/build/main" def mask_fn(env: gym.Env): return env.create_action_mask() - + def main(): - import argparse - args = parse_arguments(argparse) - - args.grid_path = F"{args.grid_path}.txt" - args.prism_path = F"{args.prism_path}.prism" - - shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula) - + args = parse_sb3_arguments() + + formula = args.formula + shield_value = args.shield_value + shield_comparison = args.shield_comparison + + shield_handler = MiniGridShieldHandler(GRID_TO_PRISM_BINARY, args.grid_file, args.prism_output_file, formula, shield_value=shield_value, shield_comparison=shield_comparison) env = gym.make(args.env, render_mode="rgb_array") - env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full) + env = MiniGridSbShieldingWrapper(env, shield_handler=shield_handler, create_shield_at_reset=False, mask_actions=args.shielding == ShieldingConfig.Full) env = ActionMasker(env, mask_fn) model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args)) - + steps = args.steps - - + + model.learn(steps) - - #W mean_reward, std_reward = evaluate_policy(model, model.get_env()) - + + + print("Learning done, hit enter") + input("") vec_env = model.get_env() obs = vec_env.reset() terminated = truncated = False while not terminated and not truncated: action_masks = None action, _states = model.predict(obs, action_masks=action_masks) + print(action) obs, reward, terminated, truncated, info = env.step(action) # action, _states = model.predict(obs, deterministic=True) # obs, rewards, dones, info = vec_env.step(action) vec_env.render("human") time.sleep(0.2) - - + + if __name__ == '__main__': - main() \ No newline at end of file + main()