#include "src/solver/TopologicalValueIterationNondeterministicLinearEquationSolver.h"

#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>

#include "src/settings/Settings.h"
#include "src/utility/vector.h"
#include "src/utility/graph.h"
#include "src/models/PseudoModel.h"
#include "src/storage/StronglyConnectedComponentDecomposition.h"
#include "src/exceptions/IllegalArgumentException.h"
#include "src/exceptions/InvalidStateException.h"

#include "log4cplus/logger.h"
#include "log4cplus/loggingmacros.h"
extern log4cplus::Logger logger;

#include "storm-config.h"
#include "cudaForStorm.h"

namespace storm {
    namespace solver {
        
        template<typename ValueType>
        TopologicalValueIterationNondeterministicLinearEquationSolver<ValueType>::TopologicalValueIterationNondeterministicLinearEquationSolver() {
            // Configure the solver from the global settings object:
            //  - "maxiter"   : upper bound on value-iteration steps per SCC group,
            //  - "precision" : convergence threshold,
            //  - "absolute"  : if set, convergence is checked absolutely; otherwise relatively.
            storm::settings::Settings& globalSettings = *storm::settings::Settings::getInstance();

            this->maximalNumberOfIterations = globalSettings.getOptionByLongName("maxiter").getArgument(0).getValueAsUnsignedInteger();
            this->precision = globalSettings.getOptionByLongName("precision").getArgument(0).getValueAsDouble();
            this->relative = !globalSettings.isSet("absolute");
        }
        
        template<typename ValueType>
        TopologicalValueIterationNondeterministicLinearEquationSolver<ValueType>::TopologicalValueIterationNondeterministicLinearEquationSolver(double precision, uint_fast64_t maximalNumberOfIterations, bool relative)
            : NativeNondeterministicLinearEquationSolver<ValueType>(precision, maximalNumberOfIterations, relative) {
            // The explicit configuration is forwarded unchanged to the base class; nothing else to set up here.
        }
        
        template<typename ValueType>
        NondeterministicLinearEquationSolver<ValueType>* TopologicalValueIterationNondeterministicLinearEquationSolver<ValueType>::clone() const {
            // Copy-construct a heap-allocated solver with this solver's configuration.
            // The caller receives ownership of the returned raw pointer.
            TopologicalValueIterationNondeterministicLinearEquationSolver* const copy = new TopologicalValueIterationNondeterministicLinearEquationSolver(*this);
            return copy;
        }
        
        template<typename ValueType>
        void TopologicalValueIterationNondeterministicLinearEquationSolver<ValueType>::solveEquationSystem(bool minimize, storm::storage::SparseMatrix<ValueType> const& A, std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<ValueType>* multiplyResult, std::vector<ValueType>* newX) const {
            // Solves x = min/max over nondeterministic choices of (A*x + b) by value iteration, one group of SCCs at a
            // time, in the order produced by getOptimalGroupingFromTopologicalSccDecomposition (topological order).
            //
            // minimize       : take the minimum (true) or maximum (false) over the choices of each state.
            // A              : row-grouped matrix; row group i holds the choices of state i.
            // x              : initial values on entry (one per state / column), solution on exit.
            // b              : one entry per row (choice) of A.
            // multiplyResult : unused by this implementation (each SCC group allocates its own scratch vector).
            // newX           : unused by this implementation.
            //
            // Throws IllegalArgumentException if the SCC decomposition is empty, and InvalidStateException if the
            // CUDA device cannot be reset. If a group does not converge within the iteration limit, the remaining
            // groups are not processed (their entries of x keep their input values) and a warning is logged.

            // Determine the SCCs of the MDP and a topological sort of their dependency graph.
            std::vector<uint_fast64_t> const& nondeterministicChoiceIndices = A.getRowGroupIndices();
            storm::models::NonDeterministicMatrixBasedPseudoModel<ValueType> pseudoModel(A, nondeterministicChoiceIndices);
            storm::storage::StronglyConnectedComponentDecomposition<ValueType> sccDecomposition(pseudoModel, false, false);

            if (sccDecomposition.size() == 0) {
                LOG4CPLUS_ERROR(logger, "Can not solve given Equation System as the SCC Decomposition returned no SCCs.");
                throw storm::exceptions::IllegalArgumentException() << "Can not solve given Equation System as the SCC Decomposition returned no SCCs.";
            }

            storm::storage::SparseMatrix<ValueType> stronglyConnectedComponentsDependencyGraph = pseudoModel.extractPartitionDependencyGraph(sccDecomposition);
            std::vector<uint_fast64_t> topologicalSort = storm::utility::graph::getTopologicalSort(stronglyConnectedComponentsDependencyGraph);

            // Group the SCCs (first = solve on GPU?, second = states of the group).
            std::vector<std::pair<bool, std::vector<uint_fast64_t>>> optimalSccs = this->getOptimalGroupingFromTopologicalSccDecomposition(sccDecomposition, topologicalSort, A);

            uint_fast64_t currentMaxLocalIterations = 0;
            uint_fast64_t localIterations = 0;
            bool converged = true;

            // Iterate over all SCCs of the MDP as specified by the topological sort. This guarantees that an SCC is only
            // solved after all SCCs it depends on have been solved.
            for (auto sccIndexIt = optimalSccs.cbegin(); sccIndexIt != optimalSccs.cend() && converged; ++sccIndexIt) {
                bool const useGpu = sccIndexIt->first;
                std::vector<uint_fast64_t> const& scc = sccIndexIt->second;

                // Build the subsystem restricted to the states of this group. Both getSubmatrix and
                // selectVectorValues order rows/entries by ascending state index (the BitVector's order), so all
                // index bookkeeping below iterates subMatrixIndices rather than `scc`, whose order is not guaranteed.
                storm::storage::BitVector subMatrixIndices(A.getColumnCount(), scc.cbegin(), scc.cend());
                storm::storage::SparseMatrix<ValueType> sccSubmatrix = A.getSubmatrix(true, subMatrixIndices, subMatrixIndices);
                std::vector<ValueType> sccSubB(sccSubmatrix.getRowCount());
                storm::utility::vector::selectVectorValues<ValueType>(sccSubB, subMatrixIndices, nondeterministicChoiceIndices, b);
                std::vector<ValueType> sccSubX(sccSubmatrix.getColumnCount());
                std::vector<ValueType> sccSubXSwap(sccSubmatrix.getColumnCount());
                std::vector<ValueType> sccMultiplyResult(sccSubmatrix.getRowCount());

                // Double buffering: currentX holds the latest iterate, nextX receives the next one.
                std::vector<ValueType>* currentX = &sccSubX;
                std::vector<ValueType>* nextX = &sccSubXSwap;

                // x has one entry per state (column), whereas b has one entry per row (choice).
                storm::utility::vector::selectVectorValues<ValueType>(sccSubX, subMatrixIndices, x);
                std::vector<uint_fast64_t> sccSubNondeterministicChoiceIndices(sccSubmatrix.getColumnCount() + 1);
                sccSubNondeterministicChoiceIndices.at(0) = 0;

                // Build the row-group indices of the subsystem and fold transitions that leave the group into b:
                // those target states were solved earlier, so A_ij * x_j is a constant contribution to row i.
                uint_fast64_t innerIndex = 0;
                uint_fast64_t outerIndex = 0;
                for (uint_fast64_t state : subMatrixIndices) {
                    uint_fast64_t const firstRow = nondeterministicChoiceIndices[state];
                    uint_fast64_t const endRow = nondeterministicChoiceIndices[state + 1];
                    sccSubNondeterministicChoiceIndices.at(outerIndex + 1) = sccSubNondeterministicChoiceIndices.at(outerIndex) + (endRow - firstRow);

                    for (uint_fast64_t row = firstRow; row != endRow; ++row) {
                        for (auto const& entry : A.getRow(row)) {
                            if (!subMatrixIndices.get(entry.first)) {
                                // Outgoing transition to a state outside the group: add Pr * x_other to b.
                                sccSubB.at(innerIndex) += entry.second * x.at(entry.first);
                            }
                        }
                        ++innerIndex;
                    }
                    ++outerIndex;
                }

                // Runs value iteration on this group's subsystem on the CPU until convergence or until the iteration
                // limit is reached. Leaves the latest iterate in *currentX and updates localIterations/converged.
                auto performCpuValueIteration = [&]() {
                    localIterations = 0;
                    converged = false;
                    while (!converged && localIterations < this->maximalNumberOfIterations) {
                        // Compute x' = A*x + b.
                        sccSubmatrix.multiplyWithVector(*currentX, sccMultiplyResult);
                        storm::utility::vector::addVectorsInPlace<ValueType>(sccMultiplyResult, sccSubB);

                        // Reduce the vector x' by applying min/max for all non-deterministic choices.
                        if (minimize) {
                            storm::utility::vector::reduceVectorMin<ValueType>(sccMultiplyResult, *nextX, sccSubNondeterministicChoiceIndices);
                        } else {
                            storm::utility::vector::reduceVectorMax<ValueType>(sccMultiplyResult, *nextX, sccSubNondeterministicChoiceIndices);
                        }

                        // Determine whether the method converged.
                        converged = storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *nextX, this->precision, this->relative);

                        std::swap(currentX, nextX);
                        ++localIterations;
                    }
                    // `this->` is required: maximalNumberOfIterations lives in a dependent base class.
                    LOG4CPLUS_INFO(logger, "Executed " << localIterations << " of max. " << this->maximalNumberOfIterations << " Iterations.");
                };

                // For the current group, perform value iteration until convergence.
                if (useGpu) {
#ifdef STORM_HAVE_CUDAFORSTORM
                    if (!resetCudaDevice()) {
                        LOG4CPLUS_ERROR(logger, "Could not reset CUDA Device, can not use CUDA Equation Solver.");
                        throw storm::exceptions::InvalidStateException() << "Could not reset CUDA Device, can not use CUDA Equation Solver.";
                    }

                    LOG4CPLUS_INFO(logger, "Device has " << getTotalCudaMemory() << " Bytes of Memory with " << getFreeCudaMemory() << "Bytes free (" << (static_cast<double>(getFreeCudaMemory()) / static_cast<double>(getTotalCudaMemory())) * 100 << "%).");
                    LOG4CPLUS_INFO(logger, "We will allocate " << (sizeof(uint_fast64_t)* sccSubmatrix.rowIndications.size() + sizeof(uint_fast64_t)* sccSubmatrix.columnsAndValues.size() * 2 + sizeof(double)* sccSubX.size() + sizeof(double)* sccSubX.size() + sizeof(double)* sccSubB.size() + sizeof(double)* sccSubB.size() + sizeof(uint_fast64_t)* sccSubNondeterministicChoiceIndices.size()) << " Bytes.");
                    LOG4CPLUS_INFO(logger, "The CUDA Runtime Version is " << getRuntimeCudaVersion());

                    std::vector<ValueType> copyX(*currentX);
                    if (minimize) {
                        basicValueIteration_mvReduce_uint64_double_minimize(this->maximalNumberOfIterations, this->precision, this->relative, sccSubmatrix.rowIndications, sccSubmatrix.columnsAndValues, copyX, sccSubB, sccSubNondeterministicChoiceIndices);
                    } else {
                        basicValueIteration_mvReduce_uint64_double_maximize(this->maximalNumberOfIterations, this->precision, this->relative, sccSubmatrix.rowIndications, sccSubmatrix.columnsAndValues, copyX, sccSubB, sccSubNondeterministicChoiceIndices);
                    }

                    // DEBUG: also solve on the CPU and compare against the CUDA result. The CPU result (*currentX) is
                    // the one written back below; the CUDA result is only used for the comparison.
                    performCpuValueIteration();

                    uint_fast64_t diffCount = 0;
                    for (size_t i = 0; i < currentX->size(); ++i) {
                        if (currentX->at(i) != copyX.at(i)) {
                            LOG4CPLUS_WARN(logger, "CUDA solution differs on index " << i << " diff. " << std::abs(currentX->at(i) - copyX.at(i)) << ", CPU: " << currentX->at(i) << ", CUDA: " << copyX.at(i));
                            std::cout << "CUDA solution differs on index " << i << " diff. " << std::abs(currentX->at(i) - copyX.at(i)) << ", CPU: " << currentX->at(i) << ", CUDA: " << copyX.at(i) << std::endl;
                            ++diffCount;
                        }
                    }
                    std::cout << "CUDA solution differed in " << diffCount << " of " << currentX->size() << " values." << std::endl;
#endif
                } else {
                    performCpuValueIteration();
                }

                // Write the group's result back into the global vector, in the same (ascending) state order that
                // was used to build the subsystem.
                innerIndex = 0;
                for (uint_fast64_t state : subMatrixIndices) {
                    x.at(state) = currentX->at(innerIndex);
                    ++innerIndex;
                }

                // The "number of iterations" of the whole method is the maximum over all local iteration counts.
                if (localIterations > currentMaxLocalIterations) {
                    currentMaxLocalIterations = localIterations;
                }
            }

            // Check if the solver converged and issue a warning otherwise.
            if (converged) {
                LOG4CPLUS_INFO(logger, "Iterative solver converged after " << currentMaxLocalIterations << " iterations.");
            } else {
                LOG4CPLUS_WARN(logger, "Iterative solver did not converge after " << currentMaxLocalIterations << " iterations.");
            }
        }

		template<typename ValueType>
		std::vector<std::pair<bool, std::vector<uint_fast64_t>>> 
			TopologicalValueIterationNondeterministicLinearEquationSolver<ValueType>::getOptimalGroupingFromTopologicalSccDecomposition(storm::storage::StronglyConnectedComponentDecomposition<ValueType> const& sccDecomposition, std::vector<uint_fast64_t> const& topologicalSort, storm::storage::SparseMatrix<ValueType> const& matrix) const {
			// Partitions the SCCs, visited in the given topological order, into groups that are solved together.
			// Each result entry is (solveOnGpu, states of the group). The states of every group are sorted in
			// ascending order.
			// With CUDA support, consecutive SCCs are merged into one GPU group as long as their estimated device
			// memory (per basicValueIteration_mvReduce_uint64_double_calculateMemorySize) fits into 95% of the
			// currently free device memory; an SCC that does not fit even on its own becomes a CPU-only group.
			// Without CUDA support, every SCC becomes its own CPU group.
			std::vector<std::pair<bool, std::vector<uint_fast64_t>>> result;
#ifdef STORM_HAVE_CUDAFORSTORM
			// 95% to have a bit of padding
			size_t const cudaFreeMemory = static_cast<size_t>(getFreeCudaMemory() * 0.95);

			// Estimated device memory of result.back(); only meaningful while result.back() is a GPU group.
			size_t currentSize = 0;
			for (auto sccIndexIt = topologicalSort.cbegin(); sccIndexIt != topologicalSort.cend(); ++sccIndexIt) {
				storm::storage::StateBlock const& scc = sccDecomposition[*sccIndexIt];

				uint_fast64_t rowCount = 0;
				uint_fast64_t entryCount = 0;
				std::vector<uint_fast64_t> rowGroups;
				rowGroups.reserve(scc.size());

				for (auto sccIt = scc.cbegin(); sccIt != scc.cend(); ++sccIt) {
					rowCount += matrix.getRowGroupSize(*sccIt);
					entryCount += matrix.getRowGroupEntryCount(*sccIt);
					rowGroups.push_back(*sccIt);
				}

				size_t const sccSize = basicValueIteration_mvReduce_uint64_double_calculateMemorySize(static_cast<size_t>(rowCount), scc.size(), static_cast<size_t>(entryCount));

				// Only a trailing GPU group may be extended; a CPU group always closes the current GPU group.
				bool const lastGroupIsGpuGroup = !result.empty() && result.back().first;
				if (sccSize > cudaFreeMemory) {
					// This SCC is too big to fit into the CUDA memory by itself: solve it on the CPU.
					result.push_back(std::make_pair(false, std::move(rowGroups)));
					currentSize = 0;
				} else if (lastGroupIsGpuGroup && (currentSize + sccSize) <= cudaFreeMemory) {
					// There is enough space left in the current GPU group.
					std::vector<uint_fast64_t>& lastGroup = result.back().second;
					lastGroup.insert(lastGroup.end(), rowGroups.begin(), rowGroups.end());
					currentSize += sccSize;
				} else {
					// Start a new GPU group with this SCC.
					result.push_back(std::make_pair(true, std::move(rowGroups)));
					currentSize = sccSize;
				}
			}
#else
			// One CPU group per SCC (the push happens once per SCC, after all of its states were collected).
			for (auto sccIndexIt = topologicalSort.cbegin(); sccIndexIt != topologicalSort.cend(); ++sccIndexIt) {
				storm::storage::StateBlock const& scc = sccDecomposition[*sccIndexIt];
				std::vector<uint_fast64_t> rowGroups(scc.cbegin(), scc.cend());
				result.push_back(std::make_pair(false, std::move(rowGroups)));
			}
#endif
			// Sort every group so its state order matches the ascending order in which the solver builds the
			// group's submatrix and sub-vectors from a BitVector.
			for (auto& group : result) {
				std::sort(group.second.begin(), group.second.end());
			}
			return result;
		}

        // Explicitly instantiate the solver for double, the only value type instantiated in this translation unit
        // (the CUDA routines called above are the `_uint64_double_` variants).
		template class TopologicalValueIterationNondeterministicLinearEquationSolver<double>;
    } // namespace solver
} // namespace storm