You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

71 lines
2.3 KiB

  1. from sb3_contrib import MaskablePPO
  2. from sb3_contrib.common.maskable.evaluation import evaluate_policy
  3. from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
  4. from sb3_contrib.common.wrappers import ActionMasker
  5. from stable_baselines3.common.callbacks import BaseCallback
  6. import gymnasium as gym
  7. from minigrid.core.actions import Actions
  8. import time
  9. from helpers import parse_arguments, create_log_dir, ShieldingConfig
  10. from shieldhandlers import MiniGridShieldHandler, create_shield_query
  11. from wrappers import MiniGridSbShieldingWrapper
  12. class CustomCallback(BaseCallback):
  13. def __init__(self, verbose: int = 0, env=None):
  14. super(CustomCallback, self).__init__(verbose)
  15. self.env = env
  16. def _on_step(self) -> bool:
  17. print(self.env.printGrid())
  18. return super()._on_step()
  19. def mask_fn(env: gym.Env):
  20. return env.create_action_mask()
  21. def main():
  22. import argparse
  23. args = parse_arguments(argparse)
  24. args.grid_path = F"{args.grid_path}.txt"
  25. args.prism_path = F"{args.prism_path}.prism"
  26. shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
  27. env = gym.make(args.env, render_mode="rgb_array")
  28. env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
  29. env = ActionMasker(env, mask_fn)
  30. callback = CustomCallback(1, env)
  31. model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
  32. steps = args.steps
  33. model.learn(steps, callback=callback)
  34. #W mean_reward, std_reward = evaluate_policy(model, model.get_env())
  35. vec_env = model.get_env()
  36. obs = vec_env.reset()
  37. terminated = truncated = False
  38. while not terminated and not truncated:
  39. action_masks = None
  40. action, _states = model.predict(obs, action_masks=action_masks)
  41. obs, reward, terminated, truncated, info = env.step(action)
  42. # action, _states = model.predict(obs, deterministic=True)
  43. # obs, rewards, dones, info = vec_env.step(action)
  44. vec_env.render("human")
  45. time.sleep(0.2)
  46. if __name__ == '__main__':
  47. main()