You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

632 lines
34 KiB

  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "common.h"
  10. int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  11. {
  12. // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
  13. typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
  14. static functype func[12];
  15. static bool init = false;
  16. if(!init)
  17. {
  18. for(int k=0; k<12; ++k)
  19. func[k] = 0;
  20. func[NOTR | (NOTR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor>::run);
  21. func[TR | (NOTR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor>::run);
  22. func[ADJ | (NOTR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor>::run);
  23. func[NOTR | (TR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor>::run);
  24. func[TR | (TR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor>::run);
  25. func[ADJ | (TR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor>::run);
  26. func[NOTR | (ADJ << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor>::run);
  27. func[TR | (ADJ << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor>::run);
  28. func[ADJ | (ADJ << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor>::run);
  29. init = true;
  30. }
  31. Scalar* a = reinterpret_cast<Scalar*>(pa);
  32. Scalar* b = reinterpret_cast<Scalar*>(pb);
  33. Scalar* c = reinterpret_cast<Scalar*>(pc);
  34. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  35. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  36. int info = 0;
  37. if(OP(*opa)==INVALID) info = 1;
  38. else if(OP(*opb)==INVALID) info = 2;
  39. else if(*m<0) info = 3;
  40. else if(*n<0) info = 4;
  41. else if(*k<0) info = 5;
  42. else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8;
  43. else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10;
  44. else if(*ldc<std::max(1,*m)) info = 13;
  45. if(info)
  46. return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6);
  47. if(beta!=Scalar(1))
  48. {
  49. if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
  50. else matrix(c, *m, *n, *ldc) *= beta;
  51. }
  52. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k);
  53. int code = OP(*opa) | (OP(*opb) << 2);
  54. func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
  55. return 0;
  56. }
  57. int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
  58. {
  59. // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
  60. typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
  61. static functype func[32];
  62. static bool init = false;
  63. if(!init)
  64. {
  65. for(int k=0; k<32; ++k)
  66. func[k] = 0;
  67. func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,ColMajor,ColMajor>::run);
  68. func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,RowMajor,ColMajor>::run);
  69. func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor>::run);
  70. func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,ColMajor,ColMajor>::run);
  71. func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,RowMajor,ColMajor>::run);
  72. func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, Conj, RowMajor,ColMajor>::run);
  73. func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,ColMajor,ColMajor>::run);
  74. func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,RowMajor,ColMajor>::run);
  75. func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor>::run);
  76. func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,ColMajor,ColMajor>::run);
  77. func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,RowMajor,ColMajor>::run);
  78. func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, Conj, RowMajor,ColMajor>::run);
  79. func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor>::run);
  80. func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor>::run);
  81. func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor>::run);
  82. func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor>::run);
  83. func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor>::run);
  84. func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor>::run);
  85. func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor>::run);
  86. func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor>::run);
  87. func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor>::run);
  88. func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor>::run);
  89. func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor>::run);
  90. func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor>::run);
  91. init = true;
  92. }
  93. Scalar* a = reinterpret_cast<Scalar*>(pa);
  94. Scalar* b = reinterpret_cast<Scalar*>(pb);
  95. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  96. int info = 0;
  97. if(SIDE(*side)==INVALID) info = 1;
  98. else if(UPLO(*uplo)==INVALID) info = 2;
  99. else if(OP(*opa)==INVALID) info = 3;
  100. else if(DIAG(*diag)==INVALID) info = 4;
  101. else if(*m<0) info = 5;
  102. else if(*n<0) info = 6;
  103. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
  104. else if(*ldb<std::max(1,*m)) info = 11;
  105. if(info)
  106. return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
  107. int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
  108. if(SIDE(*side)==LEFT)
  109. {
  110. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m);
  111. func[code](*m, *n, a, *lda, b, *ldb, blocking);
  112. }
  113. else
  114. {
  115. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n);
  116. func[code](*n, *m, a, *lda, b, *ldb, blocking);
  117. }
  118. if(alpha!=Scalar(1))
  119. matrix(b,*m,*n,*ldb) *= alpha;
  120. return 0;
  121. }
  122. // b = alpha*op(a)*b for side = 'L'or'l'
  123. // b = alpha*b*op(a) for side = 'R'or'r'
  124. int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
  125. {
  126. // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
  127. typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&);
  128. static functype func[32];
  129. static bool init = false;
  130. if(!init)
  131. {
  132. for(int k=0; k<32; ++k)
  133. func[k] = 0;
  134. func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
  135. func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
  136. func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  137. func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
  138. func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
  139. func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  140. func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
  141. func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
  142. func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  143. func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
  144. func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
  145. func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  146. func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run);
  147. func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run);
  148. func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  149. func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run);
  150. func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run);
  151. func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  152. func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run);
  153. func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run);
  154. func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  155. func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run);
  156. func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run);
  157. func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  158. init = true;
  159. }
  160. Scalar* a = reinterpret_cast<Scalar*>(pa);
  161. Scalar* b = reinterpret_cast<Scalar*>(pb);
  162. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  163. int info = 0;
  164. if(SIDE(*side)==INVALID) info = 1;
  165. else if(UPLO(*uplo)==INVALID) info = 2;
  166. else if(OP(*opa)==INVALID) info = 3;
  167. else if(DIAG(*diag)==INVALID) info = 4;
  168. else if(*m<0) info = 5;
  169. else if(*n<0) info = 6;
  170. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
  171. else if(*ldb<std::max(1,*m)) info = 11;
  172. if(info)
  173. return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6);
  174. int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
  175. if(*m==0 || *n==0)
  176. return 1;
  177. // FIXME find a way to avoid this copy
  178. Matrix<Scalar,Dynamic,Dynamic,ColMajor> tmp = matrix(b,*m,*n,*ldb);
  179. matrix(b,*m,*n,*ldb).setZero();
  180. if(SIDE(*side)==LEFT)
  181. {
  182. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m);
  183. func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking);
  184. }
  185. else
  186. {
  187. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n);
  188. func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking);
  189. }
  190. return 1;
  191. }
  192. // c = alpha*a*b + beta*c for side = 'L'or'l'
  193. // c = alpha*b*a + beta*c for side = 'R'or'r
  194. int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  195. {
  196. // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
  197. Scalar* a = reinterpret_cast<Scalar*>(pa);
  198. Scalar* b = reinterpret_cast<Scalar*>(pb);
  199. Scalar* c = reinterpret_cast<Scalar*>(pc);
  200. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  201. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  202. int info = 0;
  203. if(SIDE(*side)==INVALID) info = 1;
  204. else if(UPLO(*uplo)==INVALID) info = 2;
  205. else if(*m<0) info = 3;
  206. else if(*n<0) info = 4;
  207. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
  208. else if(*ldb<std::max(1,*m)) info = 9;
  209. else if(*ldc<std::max(1,*m)) info = 12;
  210. if(info)
  211. return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6);
  212. if(beta!=Scalar(1))
  213. {
  214. if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
  215. else matrix(c, *m, *n, *ldc) *= beta;
  216. }
  217. if(*m==0 || *n==0)
  218. {
  219. return 1;
  220. }
  221. #if ISCOMPLEX
  222. // FIXME add support for symmetric complex matrix
  223. int size = (SIDE(*side)==LEFT) ? (*m) : (*n);
  224. Matrix<Scalar,Dynamic,Dynamic,ColMajor> matA(size,size);
  225. if(UPLO(*uplo)==UP)
  226. {
  227. matA.triangularView<Upper>() = matrix(a,size,size,*lda);
  228. matA.triangularView<Lower>() = matrix(a,size,size,*lda).transpose();
  229. }
  230. else if(UPLO(*uplo)==LO)
  231. {
  232. matA.triangularView<Lower>() = matrix(a,size,size,*lda);
  233. matA.triangularView<Upper>() = matrix(a,size,size,*lda).transpose();
  234. }
  235. if(SIDE(*side)==LEFT)
  236. matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb);
  237. else if(SIDE(*side)==RIGHT)
  238. matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA;
  239. #else
  240. if(SIDE(*side)==LEFT)
  241. if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  242. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  243. else return 0;
  244. else if(SIDE(*side)==RIGHT)
  245. if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);
  246. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);
  247. else return 0;
  248. else
  249. return 0;
  250. #endif
  251. return 0;
  252. }
  253. // c = alpha*a*a' + beta*c for op = 'N'or'n'
  254. // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
  255. int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
  256. {
  257. // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
  258. typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar);
  259. static functype func[8];
  260. static bool init = false;
  261. if(!init)
  262. {
  263. for(int k=0; k<8; ++k)
  264. func[k] = 0;
  265. func[NOTR | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Upper>::run);
  266. func[TR | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Upper>::run);
  267. func[ADJ | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Upper>::run);
  268. func[NOTR | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Lower>::run);
  269. func[TR | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Lower>::run);
  270. func[ADJ | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Lower>::run);
  271. init = true;
  272. }
  273. Scalar* a = reinterpret_cast<Scalar*>(pa);
  274. Scalar* c = reinterpret_cast<Scalar*>(pc);
  275. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  276. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  277. int info = 0;
  278. if(UPLO(*uplo)==INVALID) info = 1;
  279. else if(OP(*op)==INVALID) info = 2;
  280. else if(*n<0) info = 3;
  281. else if(*k<0) info = 4;
  282. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  283. else if(*ldc<std::max(1,*n)) info = 10;
  284. if(info)
  285. return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
  286. if(beta!=Scalar(1))
  287. {
  288. if(UPLO(*uplo)==UP)
  289. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  290. else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
  291. else
  292. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  293. else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
  294. }
  295. #if ISCOMPLEX
  296. // FIXME add support for symmetric complex matrix
  297. if(UPLO(*uplo)==UP)
  298. {
  299. if(OP(*op)==NOTR)
  300. matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
  301. else
  302. matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
  303. }
  304. else
  305. {
  306. if(OP(*op)==NOTR)
  307. matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
  308. else
  309. matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
  310. }
  311. #else
  312. int code = OP(*op) | (UPLO(*uplo) << 2);
  313. func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha);
  314. #endif
  315. return 0;
  316. }
  317. // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
  318. // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
  319. int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  320. {
  321. Scalar* a = reinterpret_cast<Scalar*>(pa);
  322. Scalar* b = reinterpret_cast<Scalar*>(pb);
  323. Scalar* c = reinterpret_cast<Scalar*>(pc);
  324. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  325. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  326. int info = 0;
  327. if(UPLO(*uplo)==INVALID) info = 1;
  328. else if(OP(*op)==INVALID) info = 2;
  329. else if(*n<0) info = 3;
  330. else if(*k<0) info = 4;
  331. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  332. else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
  333. else if(*ldc<std::max(1,*n)) info = 12;
  334. if(info)
  335. return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
  336. if(beta!=Scalar(1))
  337. {
  338. if(UPLO(*uplo)==UP)
  339. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  340. else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
  341. else
  342. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  343. else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
  344. }
  345. if(*k==0)
  346. return 1;
  347. if(OP(*op)==NOTR)
  348. {
  349. if(UPLO(*uplo)==UP)
  350. {
  351. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  352. += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
  353. + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
  354. }
  355. else if(UPLO(*uplo)==LO)
  356. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  357. += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
  358. + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
  359. }
  360. else if(OP(*op)==TR || OP(*op)==ADJ)
  361. {
  362. if(UPLO(*uplo)==UP)
  363. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  364. += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
  365. + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
  366. else if(UPLO(*uplo)==LO)
  367. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  368. += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
  369. + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
  370. }
  371. return 0;
  372. }
  373. #if ISCOMPLEX
  374. // c = alpha*a*b + beta*c for side = 'L'or'l'
  375. // c = alpha*b*a + beta*c for side = 'R'or'r
  376. int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  377. {
  378. Scalar* a = reinterpret_cast<Scalar*>(pa);
  379. Scalar* b = reinterpret_cast<Scalar*>(pb);
  380. Scalar* c = reinterpret_cast<Scalar*>(pc);
  381. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  382. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  383. // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
  384. int info = 0;
  385. if(SIDE(*side)==INVALID) info = 1;
  386. else if(UPLO(*uplo)==INVALID) info = 2;
  387. else if(*m<0) info = 3;
  388. else if(*n<0) info = 4;
  389. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
  390. else if(*ldb<std::max(1,*m)) info = 9;
  391. else if(*ldc<std::max(1,*m)) info = 12;
  392. if(info)
  393. return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
  394. if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
  395. else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta;
  396. if(*m==0 || *n==0)
  397. {
  398. return 1;
  399. }
  400. if(SIDE(*side)==LEFT)
  401. {
  402. if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar,DenseIndex,RowMajor,true,Conj, ColMajor,false,false, ColMajor>
  403. ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  404. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,true,false, ColMajor,false,false, ColMajor>
  405. ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  406. else return 0;
  407. }
  408. else if(SIDE(*side)==RIGHT)
  409. {
  410. if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor>
  411. ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);*/
  412. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor>
  413. ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);
  414. else return 0;
  415. }
  416. else
  417. {
  418. return 0;
  419. }
  420. return 0;
  421. }
  422. // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
  423. // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
  424. int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
  425. {
  426. typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar);
  427. static functype func[8];
  428. static bool init = false;
  429. if(!init)
  430. {
  431. for(int k=0; k<8; ++k)
  432. func[k] = 0;
  433. func[NOTR | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Upper>::run);
  434. func[ADJ | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Upper>::run);
  435. func[NOTR | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Lower>::run);
  436. func[ADJ | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Lower>::run);
  437. init = true;
  438. }
  439. Scalar* a = reinterpret_cast<Scalar*>(pa);
  440. Scalar* c = reinterpret_cast<Scalar*>(pc);
  441. RealScalar alpha = *palpha;
  442. RealScalar beta = *pbeta;
  443. // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
  444. int info = 0;
  445. if(UPLO(*uplo)==INVALID) info = 1;
  446. else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
  447. else if(*n<0) info = 3;
  448. else if(*k<0) info = 4;
  449. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  450. else if(*ldc<std::max(1,*n)) info = 10;
  451. if(info)
  452. return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6);
  453. int code = OP(*op) | (UPLO(*uplo) << 2);
  454. if(beta!=RealScalar(1))
  455. {
  456. if(UPLO(*uplo)==UP)
  457. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  458. else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
  459. else
  460. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  461. else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
  462. if(beta!=Scalar(0))
  463. {
  464. matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
  465. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  466. }
  467. }
  468. if(*k>0 && alpha!=RealScalar(0))
  469. {
  470. func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha);
  471. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  472. }
  473. return 0;
  474. }
  475. // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
  476. // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
  477. int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  478. {
  479. Scalar* a = reinterpret_cast<Scalar*>(pa);
  480. Scalar* b = reinterpret_cast<Scalar*>(pb);
  481. Scalar* c = reinterpret_cast<Scalar*>(pc);
  482. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  483. RealScalar beta = *pbeta;
  484. int info = 0;
  485. if(UPLO(*uplo)==INVALID) info = 1;
  486. else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
  487. else if(*n<0) info = 3;
  488. else if(*k<0) info = 4;
  489. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  490. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
  491. else if(*ldc<std::max(1,*n)) info = 12;
  492. if(info)
  493. return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);
  494. if(beta!=RealScalar(1))
  495. {
  496. if(UPLO(*uplo)==UP)
  497. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  498. else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
  499. else
  500. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  501. else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
  502. if(beta!=Scalar(0))
  503. {
  504. matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
  505. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  506. }
  507. }
  508. else if(*k>0 && alpha!=Scalar(0))
  509. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  510. if(*k==0)
  511. return 1;
  512. if(OP(*op)==NOTR)
  513. {
  514. if(UPLO(*uplo)==UP)
  515. {
  516. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  517. += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
  518. + internal::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
  519. }
  520. else if(UPLO(*uplo)==LO)
  521. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  522. += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
  523. + internal::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
  524. }
  525. else if(OP(*op)==ADJ)
  526. {
  527. if(UPLO(*uplo)==UP)
  528. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  529. += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
  530. + internal::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
  531. else if(UPLO(*uplo)==LO)
  532. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  533. += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
  534. + internal::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
  535. }
  536. return 1;
  537. }
  538. #endif // ISCOMPLEX