from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3.common.callbacks import BaseCallback

import gymnasium as gym

from minigrid.core.actions import Actions

import time

from helpers import parse_arguments, create_log_dir, ShieldingConfig
from shieldhandlers import MiniGridShieldHandler, create_shield_query
from wrappers import MiniGridSbShieldingWrapper

class CustomCallback(BaseCallback):
    def __init__(self, verbose: int = 0, env=None):
        super(CustomCallback, self).__init__(verbose)
        self.env = env
        
        
    def _on_step(self) -> bool:
        print(self.env.printGrid())
        return super()._on_step()




def mask_fn(env: gym.Env):
    return env.create_action_mask()
    


def main():
    import argparse
    args = parse_arguments(argparse)
    
    args.grid_path = F"{args.grid_path}.txt"
    args.prism_path = F"{args.prism_path}.prism"
    
    shield_creator = MiniGridShieldHandler(args.grid_path, args.grid_to_prism_binary_path, args.prism_path, args.formula)
    
    env = gym.make(args.env, render_mode="rgb_array")
    env = MiniGridSbShieldingWrapper(env, shield_creator=shield_creator, shield_query_creator=create_shield_query, mask_actions=args.shielding == ShieldingConfig.Full)
    env = ActionMasker(env, mask_fn)
    callback = CustomCallback(1, env)
    model = MaskablePPO(MaskableActorCriticPolicy, env, gamma=0.4, verbose=1, tensorboard_log=create_log_dir(args))
    
    steps = args.steps
    
    
    model.learn(steps, callback=callback)
 
  #W  mean_reward, std_reward = evaluate_policy(model, model.get_env())
    
    vec_env = model.get_env()
    obs = vec_env.reset()
    terminated = truncated = False
    while not terminated and not truncated:
        action_masks = None
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, reward, terminated, truncated, info = env.step(action)
        # action, _states = model.predict(obs, deterministic=True)
        # obs, rewards, dones, info = vec_env.step(action)
        vec_env.render("human")
        time.sleep(0.2)
    
    

if __name__ == '__main__':
    main()