You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
4.3 KiB

  1. #include <iostream>
  2. #include <Eigen/Core>
  3. #include <Eigen/Dense>
  4. #include <Eigen/IterativeLinearSolvers>
  5. #include <unsupported/Eigen/IterativeSolvers>
  6. class MatrixReplacement;
  7. using StormEigen::SparseMatrix;
  8. namespace StormEigen {
  9. namespace internal {
  10. // MatrixReplacement looks-like a SparseMatrix, so let's inherits its traits:
  11. template<>
  12. struct traits<MatrixReplacement> : public StormEigen::internal::traits<StormEigen::SparseMatrix<double> >
  13. {};
  14. }
  15. }
  16. // Example of a matrix-free wrapper from a user type to Eigen's compatible type
  17. // For the sake of simplicity, this example simply wrap a StormEigen::SparseMatrix.
  18. class MatrixReplacement : public StormEigen::EigenBase<MatrixReplacement> {
  19. public:
  20. // Required typedefs, constants, and method:
  21. typedef double Scalar;
  22. typedef double RealScalar;
  23. typedef int StorageIndex;
  24. enum {
  25. ColsAtCompileTime = StormEigen::Dynamic,
  26. MaxColsAtCompileTime = StormEigen::Dynamic,
  27. IsRowMajor = false
  28. };
  29. Index rows() const { return mp_mat->rows(); }
  30. Index cols() const { return mp_mat->cols(); }
  31. template<typename Rhs>
  32. StormEigen::Product<MatrixReplacement,Rhs,StormEigen::AliasFreeProduct> operator*(const Eigen::MatrixBase<Rhs>& x) const {
  33. return StormEigen::Product<MatrixReplacement,Rhs,StormEigen::AliasFreeProduct>(*this, x.derived());
  34. }
  35. // Custom API:
  36. MatrixReplacement() : mp_mat(0) {}
  37. void attachMyMatrix(const SparseMatrix<double> &mat) {
  38. mp_mat = &mat;
  39. }
  40. const SparseMatrix<double> my_matrix() const { return *mp_mat; }
  41. private:
  42. const SparseMatrix<double> *mp_mat;
  43. };
  44. // Implementation of MatrixReplacement * StormEigen::DenseVector though a specialization of internal::generic_product_impl:
  45. namespace StormEigen {
  46. namespace internal {
  47. template<typename Rhs>
  48. struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector
  49. : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> >
  50. {
  51. typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar;
  52. template<typename Dest>
  53. static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha)
  54. {
  55. // This method should implement "dst += alpha * lhs * rhs" inplace,
  56. // however, for iterative solvers, alpha is always equal to 1, so let's not bother about it.
  57. assert(alpha==Scalar(1) && "scaling is not implemented");
  58. // Here we could simply call dst.noalias() += lhs.my_matrix() * rhs,
  59. // but let's do something fancier (and less efficient):
  60. for(Index i=0; i<lhs.cols(); ++i)
  61. dst += rhs(i) * lhs.my_matrix().col(i);
  62. }
  63. };
  64. }
  65. }
  66. int main()
  67. {
  68. int n = 10;
  69. StormEigen::SparseMatrix<double> S = StormEigen::MatrixXd::Random(n,n).sparseView(0.5,1);
  70. S = S.transpose()*S;
  71. MatrixReplacement A;
  72. A.attachMyMatrix(S);
  73. StormEigen::VectorXd b(n), x;
  74. b.setRandom();
  75. // Solve Ax = b using various iterative solver with matrix-free version:
  76. {
  77. StormEigen::ConjugateGradient<MatrixReplacement, StormEigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg;
  78. cg.compute(A);
  79. x = cg.solve(b);
  80. std::cout << "CG: #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl;
  81. }
  82. {
  83. StormEigen::BiCGSTAB<MatrixReplacement, StormEigen::IdentityPreconditioner> bicg;
  84. bicg.compute(A);
  85. x = bicg.solve(b);
  86. std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl;
  87. }
  88. {
  89. StormEigen::GMRES<MatrixReplacement, StormEigen::IdentityPreconditioner> gmres;
  90. gmres.compute(A);
  91. x = gmres.solve(b);
  92. std::cout << "GMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
  93. }
  94. {
  95. StormEigen::DGMRES<MatrixReplacement, StormEigen::IdentityPreconditioner> gmres;
  96. gmres.compute(A);
  97. x = gmres.solve(b);
  98. std::cout << "DGMRES: #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
  99. }
  100. {
  101. StormEigen::MINRES<MatrixReplacement, StormEigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres;
  102. minres.compute(A);
  103. x = minres.solve(b);
  104. std::cout << "MINRES: #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl;
  105. }
  106. }