#include <algorithm>
#include <chrono>
#include <climits>
#include <fstream>
#include <iostream>
#include <memory>
#include <queue>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "Graph.h"
#include "util/GraphParser.h"
namespace data {
  // Stores the output/reporting options; the graph itself is loaded later via parseFromString() or
  // parseFromFile().
  // NOTE(review): `stdout_output` is accepted but not used; initOstream() falls back to std::cout
  // whenever file output is disabled.
  Graph::Graph(bool stdout_output, bool file_output, std::string output_filename, bool verbose_max_flow, bool min_cut, int verbosity)
    : m_file_output(file_output), m_output_file_name(std::move(output_filename)), m_verbose_max_flow(verbose_max_flow), m_min_cut(min_cut), m_verbosity(verbosity) {
    static_cast<void>(stdout_output);
  }

  void Graph::parseFromString(const std::string &graph_string) {
    parser::parseString(graph_string, m_arc_list, m_vertices, m_source_id, m_sink_id, m_num_vertices, m_num_arcs);
    initMatrices();
    initOstream();
  }

  void Graph::parseFromFile(const std::string &graph_file) {
    if(graph_file == m_output_file_name) {
      throw std::runtime_error("Input graph file name and output file name are the same. Will not overwrite. Exiting...");
    }
    parser::parseFile(graph_file, m_arc_list, m_vertices, m_source_id, m_sink_id, m_num_vertices, m_num_arcs);
    initMatrices();
    initOstream();
  }

  // Builds the n x n matrices (0-based indices = vertex ID - 1) from the parsed arc list:
  //   m_flow        - all zero,
  //   m_capapcities - summed capacity of all arcs start -> end (parallel arcs are added up),
  //   m_network     - copy of the capacities that stays untouched by the flow computation.
  void Graph::initMatrices() {
    // assign() rather than resize(): resize() keeps existing rows and their values, so parsing a
    // second graph would add capacities onto the old ones and keep stale flow.
    m_flow.assign(m_num_vertices, std::vector<Capacity>(m_num_vertices, 0));
    m_capapcities.assign(m_num_vertices, std::vector<Capacity>(m_num_vertices, 0));
    for(auto const &arc : m_arc_list) {
      m_capapcities.at(arc.start - 1).at(arc.end - 1) += arc.capacity;
    }
    m_network = m_capapcities;
  }

  // Selects where all reports go: a newly opened (truncated) file when file output is enabled,
  // std::cout otherwise.
  // Throws std::runtime_error if the output file cannot be opened, instead of silently writing
  // into a failed stream.
  // NOTE(review): the file stream is heap-allocated and no delete is visible in this file. Graph
  // copies (e.g. in hasUniqueMinCut) presumably share the same pointer, so ownership should be
  // clarified in Graph.h.
  void Graph::initOstream() {
    if(m_file_output) {
      // unique_ptr keeps the stream from leaking if we throw below.
      std::unique_ptr<std::ofstream> file_stream(new std::ofstream(m_output_file_name));
      if(!file_stream->is_open()) {
        throw std::runtime_error("Could not open output file '" + m_output_file_name + "'.");
      }
      m_ofstream = file_stream.release();
    } else {
      m_ofstream = &std::cout;
    }
  }

  // Computes the maximum flow from source to sink and returns its value.
  // Each phase rebuilds the level graph (BFS) and then pushes flow along level-increasing paths
  // (DFS). It stops once findAugmentingPaths() reports that the sink is no longer reachable. The
  // optional reports (per-arc flow, min cut, statistics) are printed afterwards.
  // NOTE(review): flow is only pushed along forward arcs and never creates reverse residual
  // capacity, so flow cannot be cancelled. Unless the parser adds reverse arcs, the result may be
  // below the true maximum on some networks - confirm.
  int Graph::maxFlowDinic()  {
    const auto started_at = std::chrono::steady_clock::now();
    printInformation();
    for(;;) {
      constructLevelGraph();
      if(findAugmentingPaths() == NO_AUGMENTING_PATH_FOUND) {
        break;
      }
    }
    *m_ofstream << "Found max flow |x| = " << m_max_flow << "\n";
    if(m_verbose_max_flow) {
      printMaxFlowInformation();
    }
    if(m_min_cut) {
      printMinCut();
    }
    if(m_verbosity >= 1) {
      const auto finished_at = std::chrono::steady_clock::now();
      printComputationStatistics(started_at, finished_at);
    }
    return m_max_flow;
  }

  // Runs one augmentation phase on the current level graph.
  // Returns NO_AUGMENTING_PATH_FOUND when the sink has no level (unreachable) or when the source or
  // sink vertex cannot be found. Otherwise it pushes flow along every level-increasing path found by
  // buildPath() and returns 0.
  int Graph::findAugmentingPaths() {
    auto sink = std::find_if(m_vertices.begin(), m_vertices.end(), [this] (const Vertex &v) { return v.getID() == m_sink_id; });
    // Check end() before dereferencing; a missing sink means there is nothing to augment.
    if(sink == m_vertices.end() || sink->getLevel() == UNDEF_LEVEL) {
      return NO_AUGMENTING_PATH_FOUND;
    }
    for(auto &v : m_vertices) {
      v.setVisited(false);
    }
    auto source = std::find_if(m_vertices.begin(), m_vertices.end(), [this] (const Vertex &v) { return v.getID() == m_source_id; });
    if(source == m_vertices.end()) {
      return NO_AUGMENTING_PATH_FOUND;
    }
    std::vector<Vertex> path{*source};
    buildPath(path);
    return 0;
  }

  // Depth-first search over the current level graph. It extends `current_path` (which starts at
  // the source) one level at a time along arcs with positive residual capacity. Each time the sink
  // is reached, computeFlowForPath() pushes the path's bottleneck flow.
  // On return, `current_path` holds exactly what it held on entry.
  // NOTE(review): only arcs from Vertex::getOutgoingArcs() are followed, and pushing flow never
  // adds reverse residual capacity. Unless the parser also adds reverse arcs, flow cannot be
  // cancelled - confirm.
  void Graph::buildPath(std::vector<Vertex> &current_path) {
    if(m_verbosity >= 1) m_num_build_path_calls++;
    // A copy, not a reference: push_back below may reallocate current_path.
    Vertex head = current_path.back();
    if(head.getID() == m_sink_id) {
      computeFlowForPath(current_path);
      // Levels strictly increase along a path, so the sink cannot be reached again from here.
      return;
    }
    for(auto const &arc : head.getOutgoingArcs()) {
      if(m_capapcities.at(arc.start - 1).at(arc.end - 1) <= 0) continue;
      auto next = std::find_if(m_vertices.begin(), m_vertices.end(), [&arc] (const Vertex &v) { return v.getID() == arc.end; });
      // Check end() before dereferencing, and only pop what was actually pushed.
      if(next == m_vertices.end()) continue;
      if(head.getLevel() + 1 != next->getLevel()) continue;
      current_path.push_back(*next);
      buildPath(current_path);
      current_path.pop_back();
    }
  }

  void Graph::computeFlowForPath(const std::vector<Vertex> &current_path) {
    std::vector<Capacity> path_capacities;
    for(uint i = 0; i < current_path.size() - 1; i++) {
      path_capacities.push_back(m_capapcities.at(current_path.at(i).getID() - 1).at(current_path.at(i + 1).getID() - 1));
    }
    Capacity flow = *std::min_element(path_capacities.begin(), path_capacities.end());
    m_max_flow += flow;
    for(uint i = 0; i < current_path.size() - 1; i++) {
      m_capapcities.at(current_path.at(i).getID() - 1).at(current_path.at(i + 1).getID() - 1) -= flow;
      m_flow.at(current_path.at(i).getID() - 1).at(current_path.at(i + 1).getID() - 1) += flow;
    }
    if(m_verbosity >= 1) m_num_paths++;
    if(m_verbosity >= 2) {
      std::stringstream path;
      path << std::to_string(current_path.front().getID());
      for(uint i = 1; i < current_path.size(); i++) {
        path << " > " << current_path.at(i).getID();
      }
      path << " | flow = " << flow;
      m_augmenting_paths.push_back(path.str());
    }
  }

  void Graph::constructLevelGraph() {
    std::queue<Vertex> q;
    for(auto &v : m_vertices) {
      v.setLevel(UNDEF_LEVEL);
    }
    auto m_source = std::find_if(m_vertices.begin(), m_vertices.end(), [this] (const Vertex &v) { return (v.getID() == m_source_id); });
    m_source->setLevel(0);
    q.push(*m_source);
    while(!q.empty()) {
      Vertex current_vertex = q.front();
      int current_level = current_vertex.getLevel();
      q.pop();
      // restructure this to use matrix
      for(auto const &arc : current_vertex.getOutgoingArcs()) {
        if(m_capapcities.at(arc.start - 1).at(arc.end - 1) <= 0) continue;
        auto it = std::find_if(m_vertices.begin(), m_vertices.end(), [&arc] (const Vertex &v) { return (v.getID() == arc.end) && !v.hasDefinedLevel(); });
        if(it != m_vertices.end()) {
          it->setLevel(current_level + 1);
          q.push(*it);
        }
      }
    }
    if(m_verbosity >= 1) m_num_level_graphs_built++;
  }

  void Graph::hasUniqueMaxFlow() {
  *m_ofstream << "\nChecking uniqueness of maximum flow:\n";
    CapacityMatrix residualGraph = m_capapcities;
    for(uint row = 0; row < m_flow.size(); row++) {
      for(uint i = 0; i < m_flow.at(0).size(); i++) {
        residualGraph.at(i).at(row) += m_flow.at(row).at(i);
      }
    }

    int visited[m_num_vertices];
    std::stack<int> *recursive_stack = new std::stack<int>();
    bool found_cycle = false;
    for(uint i = 0; i < m_num_vertices; i++){
      visited[i] = NOT_PROCESSED;
    }

    for(uint i = 0; i < m_num_vertices; i++){
      if(visited[i] == NOT_PROCESSED) {
        visited[i] = ON_STACK;
        recursive_stack = new std::stack<int>();
        recursive_stack->push(i);
        isResidualGraphCyclic(residualGraph, visited, INVALID_VERTEX, recursive_stack, found_cycle);
      }
    }
    if(!found_cycle) *m_ofstream << "The max flow is unique!\n";
  }

  // Recursive DFS step used by hasUniqueMaxFlow(). It explores the non-zero entries of
  // `residual_graph` leaving the vertex on top of `recursive_stack`, looking for an arc back to a
  // vertex that is still ON_STACK (i.e. on the current DFS path).
  //   residual_graph  - residual capacity matrix indexed by 0-based vertex index.
  //   visited         - per-vertex DFS state (NOT_PROCESSED / ON_STACK / PROCESSED).
  //   previous        - vertex this call was entered from. The arc straight back to it is skipped,
  //                     presumably so the trivial two-vertex cycle u -> v -> u is not reported.
  //   recursive_stack - current DFS path. The caller pushes the vertex; this call pops it on exit.
  //   found_cycle     - set to true once a cycle has been reported.
  // NOTE(review): vertices already marked PROCESSED are entered again (first branch below), so the
  // search may revisit large parts of the graph repeatedly - confirm this is intended.
  void Graph::isResidualGraphCyclic(const CapacityMatrix &residual_graph, int *visited, const int previous, std::stack<int> *recursive_stack, bool &found_cycle) {
    int top = recursive_stack->top();
    // `next` is unsigned while `previous` is int; a negative `previous` (e.g. INVALID_VERTEX, if it
    // is negative) is converted before the comparison and therefore never matches a real index.
    for(uint next = 0; next < residual_graph.at(top).size(); next++) {
      if(next == previous) continue;
      if(residual_graph.at(top).at(next) == 0) continue;
      //std::cout << "prev, top, next, capacity: " << previous +1 << ", " << top +1 << ", " << next+1 << ", " << residual_graph.at(top).at(next) <<std::endl;
      if(visited[next] == NOT_PROCESSED || visited[next] == PROCESSED) {
        // Descend: `next` joins the DFS path and is popped again by the recursive call.
        recursive_stack->push(next);
        visited[next] = ON_STACK;
        isResidualGraphCyclic(residual_graph, visited, top, recursive_stack, found_cycle);
      } else if(visited[next] == ON_STACK) {
        // top -> next closes a cycle running from `next` up to `top` on the DFS path. Note that
        // printCycle() receives the whole stack, not just the part starting at `next`.
        if(recursive_stack->size() >= 3) {
          found_cycle = true;
          printCycle(residual_graph, recursive_stack);
        }
      }
    }
    // Done with `top`: mark it and remove it from the DFS path.
    visited[top] = PROCESSED;
    recursive_stack->pop();
  }

  // Prints the vertices held in `stack` (bottom first, as 1-based IDs) as a cycle that closes back
  // to the bottom vertex. It then prints the smallest residual capacity along that cycle, i.e. how
  // much flow could be shifted around it.
  // The caller's stack is left unchanged. An empty or null stack prints nothing.
  // NOTE(review): isResidualGraphCyclic() detects a cycle that closes at some ON_STACK vertex, which
  // need not be the bottom of the stack. That vertex is not passed in, so the printed cycle (and its
  // closing arc top -> bottom) may not match the detected one - confirm.
  void Graph::printCycle(const CapacityMatrix residual_matrix, std::stack<int> *stack) {
    if(stack == nullptr || stack->empty()) return;

    // Read the DFS stack through a copy: the recursive search pops it afterwards and must still find
    // every element it pushed (draining the original would lose its bottom element).
    std::vector<int> cycle;
    for(std::stack<int> pending = *stack; !pending.empty(); pending.pop()) {
      cycle.push_back(pending.top());
    }
    std::reverse(cycle.begin(), cycle.end());

    const int first = cycle.front();
    // Start with the closing arc so the minimum is always well-defined (also for a single vertex).
    auto min_capacity = residual_matrix.at(cycle.back()).at(first);
    *m_ofstream << first + 1;
    for(std::size_t i = 1; i < cycle.size(); ++i) {
      *m_ofstream << " -> " << cycle[i] + 1;
      min_capacity = std::min(min_capacity, residual_matrix.at(cycle[i - 1]).at(cycle[i]));
    }
    *m_ofstream << " -> " << first + 1;
    *m_ofstream << ", where " << min_capacity << " unit";
    if(min_capacity > 1) *m_ofstream << "s";
    *m_ofstream << " of flow could be shifted." << std::endl;
  }

  // Checks whether the current minimum cut is unique. For each arc crossing the cut, a copy of the
  // graph with that arc's capacity raised by one is solved again. If its max flow equals m_max_flow,
  // another minimum cut exists, and it is printed.
  // Presumably expects maxFlowDinic() to have run already: the cut is taken from the vertex levels of
  // the last level graph, and the comparison uses m_max_flow.
  void Graph::hasUniqueMinCut() {
    *m_ofstream << "\nChecking uniqueness of minimum cut:\n";
    std::vector<Arc> cut_edges;
    computeMinCut(nullptr, nullptr, &cut_edges);
    bool found_another_min_cut = false;
    for(auto const &arc : cut_edges) {
      Graph augmented_network = *this;
      augmented_network.incrementArcCapacity(arc.start - 1, arc.end - 1);
      augmented_network.resetNetwork();
      // The copy presumably writes through the same stream pointer as *this. Silence it while it
      // recomputes, and make sure the stream is re-enabled even if the computation throws.
      augmented_network.disableOutput();
      int augmented_max_flow = 0;
      try {
        augmented_max_flow = augmented_network.maxFlowDinic();
      } catch(...) {
        augmented_network.enableOutput();
        throw;
      }
      augmented_network.enableOutput();
      if(augmented_max_flow == m_max_flow) {
        found_another_min_cut = true;
        *m_ofstream << "Found another minimum cut after incrementing (" << arc.start << "," << arc.end << ") with a flow value of " << augmented_max_flow << "." << std::endl;
        augmented_network.printMinCut();
      }
    }
    if(!found_another_min_cut) {
      *m_ofstream << "The minimum cut is unique!\n";
    }
  }


  void Graph::printInformation() const {
    auto m_source = std::find_if(m_vertices.begin(), m_vertices.end(), [this] (const Vertex &v) { return (v.getID() == m_source_id); });
    auto m_sink = std::find_if(m_vertices.begin(), m_vertices.end(), [this] (const Vertex &v) { return (v.getID() == m_sink_id); });
    *m_ofstream << "#Vertices: " << m_num_vertices << std::endl;
    *m_ofstream << "#Arc: " << m_num_arcs << std::endl;
    *m_ofstream << "Source: " << m_source->getID() << ", Sink: " << m_sink->getID() << std::endl;
    *m_ofstream << "Vertices: ";
    bool first = true;
    for(auto const& v : m_vertices) {
      if(first) first = false;
      else *m_ofstream << ", ";
      *m_ofstream << v.getID();
    }
    *m_ofstream << std::endl;
    for(auto const& a : m_arc_list) {
      *m_ofstream << "  " << a.start << " -> " << a.end << " capacity = " << a.capacity << std::endl;
    }
    *m_ofstream << std::endl;
  }

  // Prints "start -> end flow = <flow>/<capacity>" for every parsed arc.
  // NOTE(review): m_flow has one cell per (start, end) pair and initMatrices() sums parallel arcs.
  // With parallel arcs, the printed flow is therefore the combined flow over all of them, shown
  // against each arc's own capacity.
  void Graph::printMaxFlowInformation() const {
    auto &out = *m_ofstream;
    out << "Max Flow per arc:\n";
    for(auto const &arc : m_arc_list) {
      const auto row = arc.start - 1;
      const auto col = arc.end - 1;
      out << "  " << arc.start << " -> " << arc.end
          << " flow = " << m_flow.at(row).at(col)
          << "/" << arc.capacity << "\n";
    }
  }

  // Splits the vertices by the last level graph built:
  //   - vertices with a defined level (reached from the source over arcs with positive residual
  //     capacity) are appended to `source_vertices`, all others to `sink_vertices`;
  //   - if `cut_edges` is non-null, every (source-side, sink-side) pair with positive capacity in
  //     m_network is appended to it as an arc.
  // `source_vertices` and `sink_vertices` may be nullptr. The partition is then kept in local
  // storage (still needed to compute `cut_edges`) and discarded afterwards.
  void Graph::computeMinCut(std::vector<Vertex> *source_vertices, std::vector<Vertex> *sink_vertices, std::vector<Arc> *cut_edges) const {
    // Stack-owned fallbacks replace heap allocations that were never freed.
    std::vector<Vertex> local_source_vertices;
    std::vector<Vertex> local_sink_vertices;
    if(!source_vertices) source_vertices = &local_source_vertices;
    if(!sink_vertices) sink_vertices = &local_sink_vertices;
    for(auto const &vertex : m_vertices) {
      if(vertex.getLevel() != UNDEF_LEVEL) {
        source_vertices->push_back(vertex.getID());
      } else {
        sink_vertices->push_back(vertex.getID());
      }
    }
    if(!cut_edges) return;
    for(auto const &source : *source_vertices) {
      for(auto const &sink : *sink_vertices) {
        int capacity = m_network.at(source.getID() - 1).at(sink.getID() - 1);
        if(capacity > 0) {
          cut_edges->push_back({source.getID(), sink.getID(), capacity, capacity, 0});
        }
      }
    }
  }

  void Graph::printMinCut() const {
    std::vector<Arc> *arcs = new std::vector<Arc>();
    std::vector<Vertex> *source_vertices = new std::vector<Vertex>();
    std::vector<Vertex> *sink_vertices = new std::vector<Vertex>();
    computeMinCut(source_vertices, sink_vertices, arcs);
    std::vector<std::string> min_cut, complement;
    for(auto const &vertex : *source_vertices) {
      min_cut.push_back(std::to_string(vertex.getID()));
    }
    for(auto const &vertex : *sink_vertices) {
      complement.push_back(std::to_string(vertex.getID()));
    }

    *m_ofstream << "Min Cut X: {";
    bool first = true;
    for(auto const &v : min_cut) {
      if(first) first = false;
      else *m_ofstream << ", ";
      *m_ofstream << v;
    } *m_ofstream << "}\nComplement(X): {";
    first = true;
    for(auto const &v : complement) {
      if(first) first = false;
      else *m_ofstream << ", ";
      *m_ofstream << v;
    } *m_ofstream << "}\n";
  }

  // Reports the elapsed time between `start` and `end` (in ms and µs) and the counters gathered
  // during the run. At verbosity >= 2, every recorded augmenting path is listed as well.
  void Graph::printComputationStatistics(const std::chrono::steady_clock::time_point &start, const std::chrono::steady_clock::time_point &end) const {
    using std::chrono::duration_cast;
    const auto elapsed = end - start;
    const auto elapsed_ms = duration_cast<std::chrono::milliseconds>(elapsed).count();
    const auto elapsed_us = duration_cast<std::chrono::microseconds>(elapsed).count();
    auto &out = *m_ofstream;
    out << "Elapsed time: " << elapsed_ms << "ms (" << elapsed_us << "µs).\n";
    out << "Computation Statistics:\n";
    out << "  #level graphs built: " << m_num_level_graphs_built << "\n";
    out << "  #augmenting paths computed: " << m_num_paths << "\n";
    if(m_verbosity >= 2) {
      for(auto const &recorded_path : m_augmenting_paths) {
        out << "    " << recorded_path << "\n";
      }
    }
    out << "  #recursive buildPath calls: " << m_num_build_path_calls << "\n";
  }

  // Raises the stored network capacity of the arc at matrix cell (source, sink) by one. The indices
  // are 0-based (hasUniqueMinCut passes arc.start - 1 / arc.end - 1). Only m_network is changed;
  // resetNetwork() copies it into the residual capacities.
  void Graph::incrementArcCapacity(const int source, const int sink) {
    auto &capacity = m_network.at(source).at(sink);
    capacity += 1;
  }

  void Graph::resetNetwork() {
    m_max_flow = 0;
    m_capapcities = m_network;
  }

  // Silences the output stream by setting its badbit, so writes through m_ofstream are dropped
  // until enableOutput() clears the state. The flag lives on the stream itself, so it affects every
  // writer of that stream (std::cout when file output is disabled).
  void Graph::disableOutput() {
    auto &out = *m_ofstream;
    out.setstate(std::ios::badbit);
  }

  // Re-enables the output stream by resetting its state to goodbit. This clears every error flag,
  // not only the badbit set by disableOutput().
  void Graph::enableOutput() {
    auto &out = *m_ofstream;
    out.clear(std::ios_base::goodbit);
  }

  // Debug helper: writes m_network and then the current residual capacities (m_capapcities) to
  // std::cout, one matrix row per line, with blank lines around them. The flow matrix is not
  // printed.
  void Graph::printMatrices() const {
    std::cout << std::endl;
    // Rows are bound by const reference; iterating by value would copy every row.
    for(auto const &row : m_network) {
      for(auto const &value : row) {
        std::cout << value << " ";
      }
      std::cout << std::endl;
    }
    std::cout << std::endl;
    for(auto const &row : m_capapcities) {
      for(auto const &value : row) {
        std::cout << value << " ";
      }
      std::cout << std::endl;
    }
    std::cout << std::endl;
  }
}