You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

54 lines
1.8 KiB

2 months ago
import time

import gymnasium as gym
from minigrid.core.actions import Actions
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker

from sb3utils import MiniGridSbShieldingWrapper
from utils import MiniGridShieldHandler, create_shield_query, parse_arguments, create_log_dir, ShieldingConfig
  10. def mask_fn(env: gym.Env):
  11. return env.create_action_mask()
  12. def main():
  13. import argparse
  14. args = parse_arguments(argparse)
  15. args.grid_path = F"{args.grid_path}.txt"
  16. args.prism_path = F"{args.prism_path}.prism"
  17. shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
  18. env = gym.make(args.env, render_mode="rgb_array")
  19. env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
  20. env = ActionMasker(env, mask_fn)
  21. model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
  22. steps = args.steps
  23. model.learn(steps)
  24. #W mean_reward, std_reward = evaluate_policy(model, model.get_env())
  25. vec_env = model.get_env()
  26. obs = vec_env.reset()
  27. terminated = truncated = False
  28. while not terminated and not truncated:
  29. action_masks = None
  30. action, _states = model.predict(obs, action_masks=action_masks)
  31. obs, reward, terminated, truncated, info = env.step(action)
  32. # action, _states = model.predict(obs, deterministic=True)
  33. # obs, rewards, dones, info = vec_env.step(action)
  34. vec_env.render("human")
  35. time.sleep(0.2)
  36. if __name__ == '__main__':
  37. main()