175 lines
6.0 KiB

  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra. Eigen itself is part of the KDE project.
  3. //
  4. // Copyright (C) 2008 Gael Guennebaud <g.gael@free.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef EIGEN_GSL_HELPER
  10. #define EIGEN_GSL_HELPER
  11. #include <Eigen/Core>
  12. #include <gsl/gsl_blas.h>
  13. #include <gsl/gsl_multifit.h>
  14. #include <gsl/gsl_eigen.h>
  15. #include <gsl/gsl_linalg.h>
  16. #include <gsl/gsl_complex.h>
  17. #include <gsl/gsl_complex_math.h>
  18. namespace Eigen {
  19. template<typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex> struct GslTraits
  20. {
  21. typedef gsl_matrix* Matrix;
  22. typedef gsl_vector* Vector;
  23. static Matrix createMatrix(int rows, int cols) { return gsl_matrix_alloc(rows,cols); }
  24. static Vector createVector(int size) { return gsl_vector_alloc(size); }
  25. static void free(Matrix& m) { gsl_matrix_free(m); m=0; }
  26. static void free(Vector& m) { gsl_vector_free(m); m=0; }
  27. static void prod(const Matrix& m, const Vector& v, Vector& x) { gsl_blas_dgemv(CblasNoTrans,1,m,v,0,x); }
  28. static void cholesky(Matrix& m) { gsl_linalg_cholesky_decomp(m); }
  29. static void cholesky_solve(const Matrix& m, const Vector& b, Vector& x) { gsl_linalg_cholesky_solve(m,b,x); }
  30. static void eigen_symm(const Matrix& m, Vector& eval, Matrix& evec)
  31. {
  32. gsl_eigen_symmv_workspace * w = gsl_eigen_symmv_alloc(m->size1);
  33. Matrix a = createMatrix(m->size1, m->size2);
  34. gsl_matrix_memcpy(a, m);
  35. gsl_eigen_symmv(a,eval,evec,w);
  36. gsl_eigen_symmv_sort(eval, evec, GSL_EIGEN_SORT_VAL_ASC);
  37. gsl_eigen_symmv_free(w);
  38. free(a);
  39. }
  40. static void eigen_symm_gen(const Matrix& m, const Matrix& _b, Vector& eval, Matrix& evec)
  41. {
  42. gsl_eigen_gensymmv_workspace * w = gsl_eigen_gensymmv_alloc(m->size1);
  43. Matrix a = createMatrix(m->size1, m->size2);
  44. Matrix b = createMatrix(_b->size1, _b->size2);
  45. gsl_matrix_memcpy(a, m);
  46. gsl_matrix_memcpy(b, _b);
  47. gsl_eigen_gensymmv(a,b,eval,evec,w);
  48. gsl_eigen_symmv_sort(eval, evec, GSL_EIGEN_SORT_VAL_ASC);
  49. gsl_eigen_gensymmv_free(w);
  50. free(a);
  51. }
  52. };
  53. template<typename Scalar> struct GslTraits<Scalar,true>
  54. {
  55. typedef gsl_matrix_complex* Matrix;
  56. typedef gsl_vector_complex* Vector;
  57. static Matrix createMatrix(int rows, int cols) { return gsl_matrix_complex_alloc(rows,cols); }
  58. static Vector createVector(int size) { return gsl_vector_complex_alloc(size); }
  59. static void free(Matrix& m) { gsl_matrix_complex_free(m); m=0; }
  60. static void free(Vector& m) { gsl_vector_complex_free(m); m=0; }
  61. static void cholesky(Matrix& m) { gsl_linalg_complex_cholesky_decomp(m); }
  62. static void cholesky_solve(const Matrix& m, const Vector& b, Vector& x) { gsl_linalg_complex_cholesky_solve(m,b,x); }
  63. static void prod(const Matrix& m, const Vector& v, Vector& x)
  64. { gsl_blas_zgemv(CblasNoTrans,gsl_complex_rect(1,0),m,v,gsl_complex_rect(0,0),x); }
  65. static void eigen_symm(const Matrix& m, gsl_vector* &eval, Matrix& evec)
  66. {
  67. gsl_eigen_hermv_workspace * w = gsl_eigen_hermv_alloc(m->size1);
  68. Matrix a = createMatrix(m->size1, m->size2);
  69. gsl_matrix_complex_memcpy(a, m);
  70. gsl_eigen_hermv(a,eval,evec,w);
  71. gsl_eigen_hermv_sort(eval, evec, GSL_EIGEN_SORT_VAL_ASC);
  72. gsl_eigen_hermv_free(w);
  73. free(a);
  74. }
  75. static void eigen_symm_gen(const Matrix& m, const Matrix& _b, gsl_vector* &eval, Matrix& evec)
  76. {
  77. gsl_eigen_genhermv_workspace * w = gsl_eigen_genhermv_alloc(m->size1);
  78. Matrix a = createMatrix(m->size1, m->size2);
  79. Matrix b = createMatrix(_b->size1, _b->size2);
  80. gsl_matrix_complex_memcpy(a, m);
  81. gsl_matrix_complex_memcpy(b, _b);
  82. gsl_eigen_genhermv(a,b,eval,evec,w);
  83. gsl_eigen_hermv_sort(eval, evec, GSL_EIGEN_SORT_VAL_ASC);
  84. gsl_eigen_genhermv_free(w);
  85. free(a);
  86. }
  87. };
  88. template<typename MatrixType>
  89. void convert(const MatrixType& m, gsl_matrix* &res)
  90. {
  91. // if (res)
  92. // gsl_matrix_free(res);
  93. res = gsl_matrix_alloc(m.rows(), m.cols());
  94. for (int i=0 ; i<m.rows() ; ++i)
  95. for (int j=0 ; j<m.cols(); ++j)
  96. gsl_matrix_set(res, i, j, m(i,j));
  97. }
  98. template<typename MatrixType>
  99. void convert(const gsl_matrix* m, MatrixType& res)
  100. {
  101. res.resize(int(m->size1), int(m->size2));
  102. for (int i=0 ; i<res.rows() ; ++i)
  103. for (int j=0 ; j<res.cols(); ++j)
  104. res(i,j) = gsl_matrix_get(m,i,j);
  105. }
  106. template<typename VectorType>
  107. void convert(const VectorType& m, gsl_vector* &res)
  108. {
  109. if (res) gsl_vector_free(res);
  110. res = gsl_vector_alloc(m.size());
  111. for (int i=0 ; i<m.size() ; ++i)
  112. gsl_vector_set(res, i, m[i]);
  113. }
  114. template<typename VectorType>
  115. void convert(const gsl_vector* m, VectorType& res)
  116. {
  117. res.resize (m->size);
  118. for (int i=0 ; i<res.rows() ; ++i)
  119. res[i] = gsl_vector_get(m, i);
  120. }
  121. template<typename MatrixType>
  122. void convert(const MatrixType& m, gsl_matrix_complex* &res)
  123. {
  124. res = gsl_matrix_complex_alloc(m.rows(), m.cols());
  125. for (int i=0 ; i<m.rows() ; ++i)
  126. for (int j=0 ; j<m.cols(); ++j)
  127. {
  128. gsl_matrix_complex_set(res, i, j,
  129. gsl_complex_rect(m(i,j).real(), m(i,j).imag()));
  130. }
  131. }
  132. template<typename MatrixType>
  133. void convert(const gsl_matrix_complex* m, MatrixType& res)
  134. {
  135. res.resize(int(m->size1), int(m->size2));
  136. for (int i=0 ; i<res.rows() ; ++i)
  137. for (int j=0 ; j<res.cols(); ++j)
  138. res(i,j) = typename MatrixType::Scalar(
  139. GSL_REAL(gsl_matrix_complex_get(m,i,j)),
  140. GSL_IMAG(gsl_matrix_complex_get(m,i,j)));
  141. }
  142. template<typename VectorType>
  143. void convert(const VectorType& m, gsl_vector_complex* &res)
  144. {
  145. res = gsl_vector_complex_alloc(m.size());
  146. for (int i=0 ; i<m.size() ; ++i)
  147. gsl_vector_complex_set(res, i, gsl_complex_rect(m[i].real(), m[i].imag()));
  148. }
  149. template<typename VectorType>
  150. void convert(const gsl_vector_complex* m, VectorType& res)
  151. {
  152. res.resize(m->size);
  153. for (int i=0 ; i<res.rows() ; ++i)
  154. res[i] = typename VectorType::Scalar(
  155. GSL_REAL(gsl_vector_complex_get(m, i)),
  156. GSL_IMAG(gsl_vector_complex_get(m, i)));
  157. }
  158. }
  159. #endif // EIGEN_GSL_HELPER