You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

634 lines
34 KiB

  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #include "common.h"
  10. int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  11. {
  12. // std::cerr << "in gemm " << *opa << " " << *opb << " " << *m << " " << *n << " " << *k << " " << *lda << " " << *ldb << " " << *ldc << " " << *palpha << " " << *pbeta << "\n";
  13. typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&, Eigen::internal::GemmParallelInfo<DenseIndex>*);
  14. static functype func[12];
  15. static bool init = false;
  16. if(!init)
  17. {
  18. for(int k=0; k<12; ++k)
  19. func[k] = 0;
  20. func[NOTR | (NOTR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,ColMajor,false,ColMajor>::run);
  21. func[TR | (NOTR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,false,ColMajor>::run);
  22. func[ADJ | (NOTR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor>::run);
  23. func[NOTR | (TR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,false,ColMajor>::run);
  24. func[TR | (TR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,false,ColMajor>::run);
  25. func[ADJ | (TR << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,false,ColMajor>::run);
  26. func[NOTR | (ADJ << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor>::run);
  27. func[TR | (ADJ << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,false,Scalar,RowMajor,Conj, ColMajor>::run);
  28. func[ADJ | (ADJ << 2)] = (internal::general_matrix_matrix_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,RowMajor,Conj, ColMajor>::run);
  29. init = true;
  30. }
  31. Scalar* a = reinterpret_cast<Scalar*>(pa);
  32. Scalar* b = reinterpret_cast<Scalar*>(pb);
  33. Scalar* c = reinterpret_cast<Scalar*>(pc);
  34. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  35. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  36. int info = 0;
  37. if(OP(*opa)==INVALID) info = 1;
  38. else if(OP(*opb)==INVALID) info = 2;
  39. else if(*m<0) info = 3;
  40. else if(*n<0) info = 4;
  41. else if(*k<0) info = 5;
  42. else if(*lda<std::max(1,(OP(*opa)==NOTR)?*m:*k)) info = 8;
  43. else if(*ldb<std::max(1,(OP(*opb)==NOTR)?*k:*n)) info = 10;
  44. else if(*ldc<std::max(1,*m)) info = 13;
  45. if(info)
  46. return xerbla_(SCALAR_SUFFIX_UP"GEMM ",&info,6);
  47. if(beta!=Scalar(1))
  48. {
  49. if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
  50. else matrix(c, *m, *n, *ldc) *= beta;
  51. }
  52. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k);
  53. int code = OP(*opa) | (OP(*opb) << 2);
  54. func[code](*m, *n, *k, a, *lda, b, *ldb, c, *ldc, alpha, blocking, 0);
  55. return 0;
  56. }
  57. int EIGEN_BLAS_FUNC(trsm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
  58. {
  59. // std::cerr << "in trsm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << "," << *n << " " << *palpha << " " << *lda << " " << *ldb<< "\n";
  60. typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, internal::level3_blocking<Scalar,Scalar>&);
  61. static functype func[32];
  62. static bool init = false;
  63. if(!init)
  64. {
  65. for(int k=0; k<32; ++k)
  66. func[k] = 0;
  67. func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,ColMajor,ColMajor>::run);
  68. func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,RowMajor,ColMajor>::run);
  69. func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, Conj, RowMajor,ColMajor>::run);
  70. func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,ColMajor,ColMajor>::run);
  71. func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,RowMajor,ColMajor>::run);
  72. func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, Conj, RowMajor,ColMajor>::run);
  73. func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|0, false,ColMajor,ColMajor>::run);
  74. func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, false,RowMajor,ColMajor>::run);
  75. func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|0, Conj, RowMajor,ColMajor>::run);
  76. func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|0, false,ColMajor,ColMajor>::run);
  77. func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, false,RowMajor,ColMajor>::run);
  78. func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|0, Conj, RowMajor,ColMajor>::run);
  79. func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,ColMajor,ColMajor>::run);
  80. func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,RowMajor,ColMajor>::run);
  81. func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,Conj, RowMajor,ColMajor>::run);
  82. func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,ColMajor,ColMajor>::run);
  83. func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,RowMajor,ColMajor>::run);
  84. func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,Conj, RowMajor,ColMajor>::run);
  85. func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Lower|UnitDiag,false,ColMajor,ColMajor>::run);
  86. func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,false,RowMajor,ColMajor>::run);
  87. func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheLeft, Upper|UnitDiag,Conj, RowMajor,ColMajor>::run);
  88. func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Lower|UnitDiag,false,ColMajor,ColMajor>::run);
  89. func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,false,RowMajor,ColMajor>::run);
  90. func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::triangular_solve_matrix<Scalar,DenseIndex,OnTheRight,Upper|UnitDiag,Conj, RowMajor,ColMajor>::run);
  91. init = true;
  92. }
  93. Scalar* a = reinterpret_cast<Scalar*>(pa);
  94. Scalar* b = reinterpret_cast<Scalar*>(pb);
  95. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  96. int info = 0;
  97. if(SIDE(*side)==INVALID) info = 1;
  98. else if(UPLO(*uplo)==INVALID) info = 2;
  99. else if(OP(*opa)==INVALID) info = 3;
  100. else if(DIAG(*diag)==INVALID) info = 4;
  101. else if(*m<0) info = 5;
  102. else if(*n<0) info = 6;
  103. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
  104. else if(*ldb<std::max(1,*m)) info = 11;
  105. if(info)
  106. return xerbla_(SCALAR_SUFFIX_UP"TRSM ",&info,6);
  107. int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
  108. if(SIDE(*side)==LEFT)
  109. {
  110. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m);
  111. func[code](*m, *n, a, *lda, b, *ldb, blocking);
  112. }
  113. else
  114. {
  115. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n);
  116. func[code](*n, *m, a, *lda, b, *ldb, blocking);
  117. }
  118. if(alpha!=Scalar(1))
  119. matrix(b,*m,*n,*ldb) *= alpha;
  120. return 0;
  121. }
  122. // b = alpha*op(a)*b for side = 'L'or'l'
  123. // b = alpha*b*op(a) for side = 'R'or'r'
  124. int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb)
  125. {
  126. // std::cerr << "in trmm " << *side << " " << *uplo << " " << *opa << " " << *diag << " " << *m << " " << *n << " " << *lda << " " << *ldb << " " << *palpha << "\n";
  127. typedef void (*functype)(DenseIndex, DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar, internal::level3_blocking<Scalar,Scalar>&);
  128. static functype func[32];
  129. static bool init = false;
  130. if(!init)
  131. {
  132. for(int k=0; k<32; ++k)
  133. func[k] = 0;
  134. func[NOTR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
  135. func[TR | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
  136. func[ADJ | (LEFT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  137. func[NOTR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
  138. func[TR | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
  139. func[ADJ | (RIGHT << 2) | (UP << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  140. func[NOTR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, true, ColMajor,false,ColMajor,false,ColMajor>::run);
  141. func[TR | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,false,ColMajor,false,ColMajor>::run);
  142. func[ADJ | (LEFT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  143. func[NOTR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|0, false,ColMajor,false,ColMajor,false,ColMajor>::run);
  144. func[TR | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,false,ColMajor>::run);
  145. func[ADJ | (RIGHT << 2) | (LO << 3) | (NUNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|0, false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  146. func[NOTR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run);
  147. func[TR | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run);
  148. func[ADJ | (LEFT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  149. func[NOTR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run);
  150. func[TR | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run);
  151. func[ADJ | (RIGHT << 2) | (UP << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  152. func[NOTR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,true, ColMajor,false,ColMajor,false,ColMajor>::run);
  153. func[TR | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,false,ColMajor,false,ColMajor>::run);
  154. func[ADJ | (LEFT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,true, RowMajor,Conj, ColMajor,false,ColMajor>::run);
  155. func[NOTR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Lower|UnitDiag,false,ColMajor,false,ColMajor,false,ColMajor>::run);
  156. func[TR | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,false,ColMajor>::run);
  157. func[ADJ | (RIGHT << 2) | (LO << 3) | (UNIT << 4)] = (internal::product_triangular_matrix_matrix<Scalar,DenseIndex,Upper|UnitDiag,false,ColMajor,false,RowMajor,Conj, ColMajor>::run);
  158. init = true;
  159. }
  160. Scalar* a = reinterpret_cast<Scalar*>(pa);
  161. Scalar* b = reinterpret_cast<Scalar*>(pb);
  162. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  163. int info = 0;
  164. if(SIDE(*side)==INVALID) info = 1;
  165. else if(UPLO(*uplo)==INVALID) info = 2;
  166. else if(OP(*opa)==INVALID) info = 3;
  167. else if(DIAG(*diag)==INVALID) info = 4;
  168. else if(*m<0) info = 5;
  169. else if(*n<0) info = 6;
  170. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 9;
  171. else if(*ldb<std::max(1,*m)) info = 11;
  172. if(info)
  173. return xerbla_(SCALAR_SUFFIX_UP"TRMM ",&info,6);
  174. int code = OP(*opa) | (SIDE(*side) << 2) | (UPLO(*uplo) << 3) | (DIAG(*diag) << 4);
  175. if(*m==0 || *n==0)
  176. return 1;
  177. // FIXME find a way to avoid this copy
  178. Matrix<Scalar,Dynamic,Dynamic,ColMajor> tmp = matrix(b,*m,*n,*ldb);
  179. matrix(b,*m,*n,*ldb).setZero();
  180. if(SIDE(*side)==LEFT)
  181. {
  182. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*m);
  183. func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha, blocking);
  184. }
  185. else
  186. {
  187. internal::gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic,4> blocking(*m,*n,*n);
  188. func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha, blocking);
  189. }
  190. return 1;
  191. }
  192. // c = alpha*a*b + beta*c for side = 'L'or'l'
  193. // c = alpha*b*a + beta*c for side = 'R'or'r
  194. int EIGEN_BLAS_FUNC(symm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  195. {
  196. // std::cerr << "in symm " << *side << " " << *uplo << " " << *m << "x" << *n << " lda:" << *lda << " ldb:" << *ldb << " ldc:" << *ldc << " alpha:" << *palpha << " beta:" << *pbeta << "\n";
  197. Scalar* a = reinterpret_cast<Scalar*>(pa);
  198. Scalar* b = reinterpret_cast<Scalar*>(pb);
  199. Scalar* c = reinterpret_cast<Scalar*>(pc);
  200. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  201. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  202. int info = 0;
  203. if(SIDE(*side)==INVALID) info = 1;
  204. else if(UPLO(*uplo)==INVALID) info = 2;
  205. else if(*m<0) info = 3;
  206. else if(*n<0) info = 4;
  207. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
  208. else if(*ldb<std::max(1,*m)) info = 9;
  209. else if(*ldc<std::max(1,*m)) info = 12;
  210. if(info)
  211. return xerbla_(SCALAR_SUFFIX_UP"SYMM ",&info,6);
  212. if(beta!=Scalar(1))
  213. {
  214. if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
  215. else matrix(c, *m, *n, *ldc) *= beta;
  216. }
  217. if(*m==0 || *n==0)
  218. {
  219. return 1;
  220. }
  221. #if ISCOMPLEX
  222. // FIXME add support for symmetric complex matrix
  223. int size = (SIDE(*side)==LEFT) ? (*m) : (*n);
  224. Matrix<Scalar,Dynamic,Dynamic,ColMajor> matA(size,size);
  225. if(UPLO(*uplo)==UP)
  226. {
  227. matA.triangularView<Upper>() = matrix(a,size,size,*lda);
  228. matA.triangularView<Lower>() = matrix(a,size,size,*lda).transpose();
  229. }
  230. else if(UPLO(*uplo)==LO)
  231. {
  232. matA.triangularView<Lower>() = matrix(a,size,size,*lda);
  233. matA.triangularView<Upper>() = matrix(a,size,size,*lda).transpose();
  234. }
  235. if(SIDE(*side)==LEFT)
  236. matrix(c, *m, *n, *ldc) += alpha * matA * matrix(b, *m, *n, *ldb);
  237. else if(SIDE(*side)==RIGHT)
  238. matrix(c, *m, *n, *ldc) += alpha * matrix(b, *m, *n, *ldb) * matA;
  239. #else
  240. if(SIDE(*side)==LEFT)
  241. if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, RowMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  242. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,true,false, ColMajor,false,false, ColMajor>::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  243. else return 0;
  244. else if(SIDE(*side)==RIGHT)
  245. if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, RowMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);
  246. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar, DenseIndex, ColMajor,false,false, ColMajor,true,false, ColMajor>::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);
  247. else return 0;
  248. else
  249. return 0;
  250. #endif
  251. return 0;
  252. }
  253. // c = alpha*a*a' + beta*c for op = 'N'or'n'
  254. // c = alpha*a'*a + beta*c for op = 'T'or't','C'or'c'
  255. int EIGEN_BLAS_FUNC(syrk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
  256. {
  257. // std::cerr << "in syrk " << *uplo << " " << *op << " " << *n << " " << *k << " " << *palpha << " " << *lda << " " << *pbeta << " " << *ldc << "\n";
  258. #if !ISCOMPLEX
  259. typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar);
  260. static functype func[8];
  261. static bool init = false;
  262. if(!init)
  263. {
  264. for(int k=0; k<8; ++k)
  265. func[k] = 0;
  266. func[NOTR | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Upper>::run);
  267. func[TR | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Upper>::run);
  268. func[ADJ | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Upper>::run);
  269. func[NOTR | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,ColMajor,Conj, Lower>::run);
  270. func[TR | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,false,Scalar,ColMajor,ColMajor,Conj, Lower>::run);
  271. func[ADJ | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,ColMajor,false,Lower>::run);
  272. init = true;
  273. }
  274. #endif
  275. Scalar* a = reinterpret_cast<Scalar*>(pa);
  276. Scalar* c = reinterpret_cast<Scalar*>(pc);
  277. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  278. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  279. int info = 0;
  280. if(UPLO(*uplo)==INVALID) info = 1;
  281. else if(OP(*op)==INVALID) info = 2;
  282. else if(*n<0) info = 3;
  283. else if(*k<0) info = 4;
  284. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  285. else if(*ldc<std::max(1,*n)) info = 10;
  286. if(info)
  287. return xerbla_(SCALAR_SUFFIX_UP"SYRK ",&info,6);
  288. if(beta!=Scalar(1))
  289. {
  290. if(UPLO(*uplo)==UP)
  291. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  292. else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
  293. else
  294. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  295. else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
  296. }
  297. #if ISCOMPLEX
  298. // FIXME add support for symmetric complex matrix
  299. if(UPLO(*uplo)==UP)
  300. {
  301. if(OP(*op)==NOTR)
  302. matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
  303. else
  304. matrix(c, *n, *n, *ldc).triangularView<Upper>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
  305. }
  306. else
  307. {
  308. if(OP(*op)==NOTR)
  309. matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*n,*k,*lda) * matrix(a,*n,*k,*lda).transpose();
  310. else
  311. matrix(c, *n, *n, *ldc).triangularView<Lower>() += alpha * matrix(a,*k,*n,*lda).transpose() * matrix(a,*k,*n,*lda);
  312. }
  313. #else
  314. int code = OP(*op) | (UPLO(*uplo) << 2);
  315. func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha);
  316. #endif
  317. return 0;
  318. }
  319. // c = alpha*a*b' + alpha*b*a' + beta*c for op = 'N'or'n'
  320. // c = alpha*a'*b + alpha*b'*a + beta*c for op = 'T'or't'
  321. int EIGEN_BLAS_FUNC(syr2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  322. {
  323. Scalar* a = reinterpret_cast<Scalar*>(pa);
  324. Scalar* b = reinterpret_cast<Scalar*>(pb);
  325. Scalar* c = reinterpret_cast<Scalar*>(pc);
  326. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  327. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  328. int info = 0;
  329. if(UPLO(*uplo)==INVALID) info = 1;
  330. else if(OP(*op)==INVALID) info = 2;
  331. else if(*n<0) info = 3;
  332. else if(*k<0) info = 4;
  333. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  334. else if(*ldb<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
  335. else if(*ldc<std::max(1,*n)) info = 12;
  336. if(info)
  337. return xerbla_(SCALAR_SUFFIX_UP"SYR2K",&info,6);
  338. if(beta!=Scalar(1))
  339. {
  340. if(UPLO(*uplo)==UP)
  341. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  342. else matrix(c, *n, *n, *ldc).triangularView<Upper>() *= beta;
  343. else
  344. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  345. else matrix(c, *n, *n, *ldc).triangularView<Lower>() *= beta;
  346. }
  347. if(*k==0)
  348. return 1;
  349. if(OP(*op)==NOTR)
  350. {
  351. if(UPLO(*uplo)==UP)
  352. {
  353. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  354. += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
  355. + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
  356. }
  357. else if(UPLO(*uplo)==LO)
  358. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  359. += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).transpose()
  360. + alpha*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).transpose();
  361. }
  362. else if(OP(*op)==TR || OP(*op)==ADJ)
  363. {
  364. if(UPLO(*uplo)==UP)
  365. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  366. += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
  367. + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
  368. else if(UPLO(*uplo)==LO)
  369. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  370. += alpha*matrix(a, *k, *n, *lda).transpose()*matrix(b, *k, *n, *ldb)
  371. + alpha*matrix(b, *k, *n, *ldb).transpose()*matrix(a, *k, *n, *lda);
  372. }
  373. return 0;
  374. }
  375. #if ISCOMPLEX
  376. // c = alpha*a*b + beta*c for side = 'L'or'l'
  377. // c = alpha*b*a + beta*c for side = 'R'or'r
  378. int EIGEN_BLAS_FUNC(hemm)(char *side, char *uplo, int *m, int *n, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  379. {
  380. Scalar* a = reinterpret_cast<Scalar*>(pa);
  381. Scalar* b = reinterpret_cast<Scalar*>(pb);
  382. Scalar* c = reinterpret_cast<Scalar*>(pc);
  383. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  384. Scalar beta = *reinterpret_cast<Scalar*>(pbeta);
  385. // std::cerr << "in hemm " << *side << " " << *uplo << " " << *m << " " << *n << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
  386. int info = 0;
  387. if(SIDE(*side)==INVALID) info = 1;
  388. else if(UPLO(*uplo)==INVALID) info = 2;
  389. else if(*m<0) info = 3;
  390. else if(*n<0) info = 4;
  391. else if(*lda<std::max(1,(SIDE(*side)==LEFT)?*m:*n)) info = 7;
  392. else if(*ldb<std::max(1,*m)) info = 9;
  393. else if(*ldc<std::max(1,*m)) info = 12;
  394. if(info)
  395. return xerbla_(SCALAR_SUFFIX_UP"HEMM ",&info,6);
  396. if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
  397. else if(beta!=Scalar(1)) matrix(c, *m, *n, *ldc) *= beta;
  398. if(*m==0 || *n==0)
  399. {
  400. return 1;
  401. }
  402. if(SIDE(*side)==LEFT)
  403. {
  404. if(UPLO(*uplo)==UP) internal::product_selfadjoint_matrix<Scalar,DenseIndex,RowMajor,true,Conj, ColMajor,false,false, ColMajor>
  405. ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  406. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,true,false, ColMajor,false,false, ColMajor>
  407. ::run(*m, *n, a, *lda, b, *ldb, c, *ldc, alpha);
  408. else return 0;
  409. }
  410. else if(SIDE(*side)==RIGHT)
  411. {
  412. if(UPLO(*uplo)==UP) matrix(c,*m,*n,*ldc) += alpha * matrix(b,*m,*n,*ldb) * matrix(a,*n,*n,*lda).selfadjointView<Upper>();/*internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, RowMajor,true,Conj, ColMajor>
  413. ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);*/
  414. else if(UPLO(*uplo)==LO) internal::product_selfadjoint_matrix<Scalar,DenseIndex,ColMajor,false,false, ColMajor,true,false, ColMajor>
  415. ::run(*m, *n, b, *ldb, a, *lda, c, *ldc, alpha);
  416. else return 0;
  417. }
  418. else
  419. {
  420. return 0;
  421. }
  422. return 0;
  423. }
  424. // c = alpha*a*conj(a') + beta*c for op = 'N'or'n'
  425. // c = alpha*conj(a')*a + beta*c for op = 'C'or'c'
  426. int EIGEN_BLAS_FUNC(herk)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pbeta, RealScalar *pc, int *ldc)
  427. {
  428. typedef void (*functype)(DenseIndex, DenseIndex, const Scalar *, DenseIndex, const Scalar *, DenseIndex, Scalar *, DenseIndex, Scalar);
  429. static functype func[8];
  430. static bool init = false;
  431. if(!init)
  432. {
  433. for(int k=0; k<8; ++k)
  434. func[k] = 0;
  435. func[NOTR | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Upper>::run);
  436. func[ADJ | (UP << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Upper>::run);
  437. func[NOTR | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,ColMajor,false,Scalar,RowMajor,Conj, ColMajor,Lower>::run);
  438. func[ADJ | (LO << 2)] = (internal::general_matrix_matrix_triangular_product<DenseIndex,Scalar,RowMajor,Conj, Scalar,ColMajor,false,ColMajor,Lower>::run);
  439. init = true;
  440. }
  441. Scalar* a = reinterpret_cast<Scalar*>(pa);
  442. Scalar* c = reinterpret_cast<Scalar*>(pc);
  443. RealScalar alpha = *palpha;
  444. RealScalar beta = *pbeta;
  445. // std::cerr << "in herk " << *uplo << " " << *op << " " << *n << " " << *k << " " << alpha << " " << *lda << " " << beta << " " << *ldc << "\n";
  446. int info = 0;
  447. if(UPLO(*uplo)==INVALID) info = 1;
  448. else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
  449. else if(*n<0) info = 3;
  450. else if(*k<0) info = 4;
  451. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  452. else if(*ldc<std::max(1,*n)) info = 10;
  453. if(info)
  454. return xerbla_(SCALAR_SUFFIX_UP"HERK ",&info,6);
  455. int code = OP(*op) | (UPLO(*uplo) << 2);
  456. if(beta!=RealScalar(1))
  457. {
  458. if(UPLO(*uplo)==UP)
  459. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  460. else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
  461. else
  462. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  463. else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
  464. if(beta!=Scalar(0))
  465. {
  466. matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
  467. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  468. }
  469. }
  470. if(*k>0 && alpha!=RealScalar(0))
  471. {
  472. func[code](*n, *k, a, *lda, a, *lda, c, *ldc, alpha);
  473. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  474. }
  475. return 0;
  476. }
  477. // c = alpha*a*conj(b') + conj(alpha)*b*conj(a') + beta*c, for op = 'N'or'n'
  478. // c = alpha*conj(a')*b + conj(alpha)*conj(b')*a + beta*c, for op = 'C'or'c'
  479. int EIGEN_BLAS_FUNC(her2k)(char *uplo, char *op, int *n, int *k, RealScalar *palpha, RealScalar *pa, int *lda, RealScalar *pb, int *ldb, RealScalar *pbeta, RealScalar *pc, int *ldc)
  480. {
  481. Scalar* a = reinterpret_cast<Scalar*>(pa);
  482. Scalar* b = reinterpret_cast<Scalar*>(pb);
  483. Scalar* c = reinterpret_cast<Scalar*>(pc);
  484. Scalar alpha = *reinterpret_cast<Scalar*>(palpha);
  485. RealScalar beta = *pbeta;
  486. int info = 0;
  487. if(UPLO(*uplo)==INVALID) info = 1;
  488. else if((OP(*op)==INVALID) || (OP(*op)==TR)) info = 2;
  489. else if(*n<0) info = 3;
  490. else if(*k<0) info = 4;
  491. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 7;
  492. else if(*lda<std::max(1,(OP(*op)==NOTR)?*n:*k)) info = 9;
  493. else if(*ldc<std::max(1,*n)) info = 12;
  494. if(info)
  495. return xerbla_(SCALAR_SUFFIX_UP"HER2K",&info,6);
  496. if(beta!=RealScalar(1))
  497. {
  498. if(UPLO(*uplo)==UP)
  499. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Upper>().setZero();
  500. else matrix(c, *n, *n, *ldc).triangularView<StrictlyUpper>() *= beta;
  501. else
  502. if(beta==Scalar(0)) matrix(c, *n, *n, *ldc).triangularView<Lower>().setZero();
  503. else matrix(c, *n, *n, *ldc).triangularView<StrictlyLower>() *= beta;
  504. if(beta!=Scalar(0))
  505. {
  506. matrix(c, *n, *n, *ldc).diagonal().real() *= beta;
  507. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  508. }
  509. }
  510. else if(*k>0 && alpha!=Scalar(0))
  511. matrix(c, *n, *n, *ldc).diagonal().imag().setZero();
  512. if(*k==0)
  513. return 1;
  514. if(OP(*op)==NOTR)
  515. {
  516. if(UPLO(*uplo)==UP)
  517. {
  518. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  519. += alpha *matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
  520. + internal::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
  521. }
  522. else if(UPLO(*uplo)==LO)
  523. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  524. += alpha*matrix(a, *n, *k, *lda)*matrix(b, *n, *k, *ldb).adjoint()
  525. + internal::conj(alpha)*matrix(b, *n, *k, *ldb)*matrix(a, *n, *k, *lda).adjoint();
  526. }
  527. else if(OP(*op)==ADJ)
  528. {
  529. if(UPLO(*uplo)==UP)
  530. matrix(c, *n, *n, *ldc).triangularView<Upper>()
  531. += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
  532. + internal::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
  533. else if(UPLO(*uplo)==LO)
  534. matrix(c, *n, *n, *ldc).triangularView<Lower>()
  535. += alpha*matrix(a, *k, *n, *lda).adjoint()*matrix(b, *k, *n, *ldb)
  536. + internal::conj(alpha)*matrix(b, *k, *n, *ldb).adjoint()*matrix(a, *k, *n, *lda);
  537. }
  538. return 1;
  539. }
  540. #endif // ISCOMPLEX