You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

942 lines
48 KiB

  1. /* -*- c++ -*- (enables emacs c++ mode) */
  2. /*===========================================================================
  3. Copyright (C) 2003-2015 Yves Renard
  4. This file is a part of GETFEM++
  5. Getfem++ is free software; you can redistribute it and/or modify it
  6. under the terms of the GNU Lesser General Public License as published
  7. by the Free Software Foundation; either version 3 of the License, or
  8. (at your option) any later version along with the GCC Runtime Library
  9. Exception either version 3.1 or (at your option) any later version.
  10. This program is distributed in the hope that it will be useful, but
  11. WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
  12. or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
  13. License and GCC Runtime Library Exception for more details.
  14. You should have received a copy of the GNU Lesser General Public License
  15. along with this program; if not, write to the Free Software Foundation,
  16. Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.
  17. As a special exception, you may use this file as it is a part of a free
  18. software library without restriction. Specifically, if other files
  19. instantiate templates or use macros or inline functions from this file,
  20. or you compile this file and link it with other files to produce an
  21. executable, this file does not by itself cause the resulting executable
  22. to be covered by the GNU Lesser General Public License. This exception
  23. does not however invalidate any other reasons why the executable file
  24. might be covered by the GNU Lesser General Public License.
  25. ===========================================================================*/
  26. /**@file gmm_blas_interface.h
  27. @author Yves Renard <Yves.Renard@insa-lyon.fr>
  28. @date October 7, 2003.
  29. @brief gmm interface for fortran BLAS.
  30. */
  31. #if defined(GETFEM_USES_BLAS) || defined(GMM_USES_BLAS) \
  32. || defined(GMM_USES_LAPACK) || defined(GMM_USES_ATLAS)
  33. #ifndef GMM_BLAS_INTERFACE_H
  34. #define GMM_BLAS_INTERFACE_H
  35. #include "gmm_blas.h"
  36. #include "gmm_interface.h"
  37. #include "gmm_matrix.h"
  38. namespace gmm {
  /* Tracing hook invoked at the top of every wrapper generated in this file.
     The active definition expands to nothing; the commented-out alternative
     prints the label passed by the wrapper each time it is entered. */
  39. #define GMMLAPACK_TRACE(f)
  40. // #define GMMLAPACK_TRACE(f) cout << "function " << f << " called" << endl;
  41. /* ********************************************************************* */
  42. /* Operations interfaced for T = float, double, std::complex<float> */
  43. /* or std::complex<double> : */
  44. /* */
  45. /* vect_norm2(std::vector<T>) */
  46. /* */
  47. /* vect_sp(std::vector<T>, std::vector<T>) */
  48. /* vect_sp(scaled(std::vector<T>), std::vector<T>) */
  49. /* vect_sp(std::vector<T>, scaled(std::vector<T>)) */
  50. /* vect_sp(scaled(std::vector<T>), scaled(std::vector<T>)) */
  51. /* */
  52. /* vect_hp(std::vector<T>, std::vector<T>) */
  53. /* vect_hp(scaled(std::vector<T>), std::vector<T>) */
  54. /* vect_hp(std::vector<T>, scaled(std::vector<T>)) */
  55. /* vect_hp(scaled(std::vector<T>), scaled(std::vector<T>)) */
  56. /* */
  57. /* add(std::vector<T>, std::vector<T>) */
  58. /* add(scaled(std::vector<T>, a), std::vector<T>) */
  59. /* */
  60. /* mult(dense_matrix<T>, dense_matrix<T>, dense_matrix<T>) */
  61. /* mult(transposed(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
  62. /* mult(dense_matrix<T>, transposed(dense_matrix<T>), dense_matrix<T>) */
  63. /* mult(transposed(dense_matrix<T>), transposed(dense_matrix<T>), */
  64. /* dense_matrix<T>) */
  65. /* mult(conjugated(dense_matrix<T>), dense_matrix<T>, dense_matrix<T>) */
  66. /* mult(dense_matrix<T>, conjugated(dense_matrix<T>), dense_matrix<T>) */
  67. /* mult(conjugated(dense_matrix<T>), conjugated(dense_matrix<T>), */
  68. /* dense_matrix<T>) */
  69. /* */
  70. /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>) */
  71. /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
  72. /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
  73. /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
  74. /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
  75. /* std::vector<T>) */
  76. /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
  77. /* std::vector<T>) */
  78. /* */
  79. /* mult_add(dense_matrix<T>, std::vector<T>, std::vector<T>) */
  80. /* mult_add(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>) */
  81. /* mult_add(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>) */
  82. /* mult_add(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>) */
  83. /* mult_add(transposed(dense_matrix<T>), scaled(std::vector<T>), */
  84. /* std::vector<T>) */
  85. /* mult_add(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
  86. /* std::vector<T>) */
  87. /* */
  88. /* mult(dense_matrix<T>, std::vector<T>, std::vector<T>, std::vector<T>) */
  89. /* mult(transposed(dense_matrix<T>), std::vector<T>, std::vector<T>, */
  90. /* std::vector<T>) */
  91. /* mult(conjugated(dense_matrix<T>), std::vector<T>, std::vector<T>, */
  92. /* std::vector<T>) */
  93. /* mult(dense_matrix<T>, scaled(std::vector<T>), std::vector<T>, */
  94. /* std::vector<T>) */
  95. /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
  96. /* std::vector<T>, std::vector<T>) */
  97. /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
  98. /* std::vector<T>, std::vector<T>) */
  99. /* mult(dense_matrix<T>, std::vector<T>, scaled(std::vector<T>), */
  100. /* std::vector<T>) */
  101. /* mult(transposed(dense_matrix<T>), std::vector<T>, */
  102. /* scaled(std::vector<T>), std::vector<T>) */
  103. /* mult(conjugated(dense_matrix<T>), std::vector<T>, */
  104. /* scaled(std::vector<T>), std::vector<T>) */
  105. /* mult(dense_matrix<T>, scaled(std::vector<T>), scaled(std::vector<T>), */
  106. /* std::vector<T>) */
  107. /* mult(transposed(dense_matrix<T>), scaled(std::vector<T>), */
  108. /* scaled(std::vector<T>), std::vector<T>) */
  109. /* mult(conjugated(dense_matrix<T>), scaled(std::vector<T>), */
  110. /* scaled(std::vector<T>), std::vector<T>) */
  111. /* */
  112. /* lower_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
  113. /* upper_tri_solve(dense_matrix<T>, std::vector<T>, k, b) */
  114. /* lower_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
  115. /* upper_tri_solve(transposed(dense_matrix<T>), std::vector<T>, k, b) */
  116. /* lower_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
  117. /* upper_tri_solve(conjugated(dense_matrix<T>), std::vector<T>, k, b) */
  118. /* */
  119. /* rank_one_update(dense_matrix<T>, std::vector<T>, std::vector<T>) */
  120. /* rank_one_update(dense_matrix<T>, scaled(std::vector<T>), */
  121. /* std::vector<T>) */
  122. /* rank_one_update(dense_matrix<T>, std::vector<T>, */
  123. /* scaled(std::vector<T>)) */
  124. /* */
  125. /* ********************************************************************* */
  126. /* ********************************************************************* */
  127. /* Basic defines. */
  128. /* ********************************************************************* */
  /* Short aliases for the four scalar types handled in this file. The suffix
     letter is the one that prefixes the BLAS routine each alias is paired
     with in the instantiations below (s, d, c, z). */
  129. # define BLAS_S float
  130. # define BLAS_D double
  131. # define BLAS_C std::complex<float>
  132. # define BLAS_Z std::complex<double>
  133. /* ********************************************************************* */
  134. /* BLAS functions used. */
  135. /* ********************************************************************* */
  /* Declarations of the Fortran BLAS entry points called by the wrappers in
     this file. Every argument is passed by address. Only daxpy_ and dgemm_
     carry a full parameter list; all other routines are declared with a bare
     `(...)`, so the compiler performs no argument checking on those calls.
     NOTE(review): calling a non-variadic routine through a variadic
     prototype may not match the platform calling convention on every
     target - confirm for each supported ABI.
     NOTE(review): the ?dotu_ / ?dotc_ declarations return std::complex by
     value; how a Fortran function returns a complex value depends on the
     Fortran compiler - verify against the BLAS actually linked. */
  136. extern "C" {
  137. void daxpy_(const int *n, const double *alpha, const double *x,
  138. const int *incx, double *y, const int *incy);
  139. void dgemm_(const char *tA, const char *tB, const int *m,
  140. const int *n, const int *k, const double *alpha,
  141. const double *A, const int *ldA, const double *B,
  142. const int *ldB, const double *beta, double *C,
  143. const int *ldC);
  144. void sgemm_(...); void cgemm_(...); void zgemm_(...);
  145. void sgemv_(...); void dgemv_(...); void cgemv_(...); void zgemv_(...);
  146. void strsv_(...); void dtrsv_(...); void ctrsv_(...); void ztrsv_(...);
  147. void saxpy_(...); /*void daxpy_(...); */void caxpy_(...); void zaxpy_(...);
  148. BLAS_S sdot_ (...); BLAS_D ddot_ (...);
  149. BLAS_C cdotu_(...); BLAS_Z zdotu_(...);
  150. BLAS_C cdotc_(...); BLAS_Z zdotc_(...);
  151. BLAS_S snrm2_(...); BLAS_D dnrm2_(...);
  152. BLAS_S scnrm2_(...); BLAS_D dznrm2_(...);
  153. void sger_(...); void dger_(...); void cgerc_(...); void zgerc_(...);
  154. }
  155. /* ********************************************************************* */
  156. /* vect_norm2(x). */
  157. /* ********************************************************************* */
  /* nrm2_interface(param1, trans1, blas_name, base_type) defines an inline
     overload of vect_norm2 whose parameter list is param1(base_type) and
     whose return type is number_traits<base_type>::magnitude_type.
     The body runs trans1(base_type) (empty for the plain-vector case used
     below) and forwards the element count, the address of x[0] and a unit
     stride to blas_name. The element count is converted to int.
     NOTE(review): &x[0] is evaluated even when x is empty - confirm callers
     never reach this overload with an empty vector. */
  158. # define nrm2_interface(param1, trans1, blas_name, base_type) \
  159. inline number_traits<base_type >::magnitude_type \
  160. vect_norm2(param1(base_type)) { \
  161. GMMLAPACK_TRACE("nrm2_interface"); \
  162. int inc(1), n(int(vect_size(x))); trans1(base_type); \
  163. return blas_name(&n, &x[0], &inc); \
  164. }
  /* Parameter declaration and (empty) preparation step for a plain vector. */
  165. # define nrm2_p1(base_type) const std::vector<base_type > &x
  166. # define nrm2_trans1(base_type)
  /* One overload per scalar type. */
  167. nrm2_interface(nrm2_p1, nrm2_trans1, snrm2_ , BLAS_S)
  168. nrm2_interface(nrm2_p1, nrm2_trans1, dnrm2_ , BLAS_D)
  169. nrm2_interface(nrm2_p1, nrm2_trans1, scnrm2_, BLAS_C)
  170. nrm2_interface(nrm2_p1, nrm2_trans1, dznrm2_, BLAS_Z)
  171. /* ********************************************************************* */
  172. /* vect_sp(x, y). */
  173. /* ********************************************************************* */
  /* dot_interface generates inline vect_sp overloads.
       param1 / param2 : declarations of the two parameters.
       trans1 / trans2 : statements run first; they must leave the names x
                         and y usable as vectors in the BLAS call.
       mult1 / mult2   : token sequences pasted in front of the BLAS call -
                         either `a*` / `b*` to apply a scale factor, or a
                         parenthesised type, which then acts as a cast.
       blas_name       : BLAS routine called with (n, x, 1, y, 1).
     The length n is taken from y only and converted to int. */
  174. # define dot_interface(param1, trans1, mult1, param2, trans2, mult2, \
  175. blas_name, base_type) \
  176. inline base_type vect_sp(param1(base_type), param2(base_type)) { \
  177. GMMLAPACK_TRACE("dot_interface"); \
  178. trans1(base_type); trans2(base_type); int inc(1), n(int(vect_size(y)));\
  179. return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \
  180. }
  /* First argument: plain vector (no preparation), or scaled vector, in
     which case x is bound to the vector obtained through linalg_origin and
     the scale factor x_.r is copied into a. */
  181. # define dot_p1(base_type) const std::vector<base_type > &x
  182. # define dot_trans1(base_type)
  183. # define dot_p1_s(base_type) \
  184. const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
  185. # define dot_trans1_s(base_type) \
  186. std::vector<base_type > &x = \
  187. const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
  188. base_type a(x_.r)
  /* Second argument: same scheme, producing y and the scale factor b. */
  189. # define dot_p2(base_type) const std::vector<base_type > &y
  190. # define dot_trans2(base_type)
  191. # define dot_p2_s(base_type) \
  192. const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
  193. # define dot_trans2_s(base_type) \
  194. std::vector<base_type > &y = \
  195. const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
  196. base_type b(y_.r)
  /* vect_sp(vector, vector). Real types use ?dot_, complex types ?dotu_. */
  197. dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2, dot_trans2, (BLAS_S),
  198. sdot_ , BLAS_S)
  199. dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2, dot_trans2, (BLAS_D),
  200. ddot_ , BLAS_D)
  201. dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2, dot_trans2, (BLAS_C),
  202. cdotu_, BLAS_C)
  203. dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2, dot_trans2, (BLAS_Z),
  204. zdotu_, BLAS_Z)
  /* vect_sp(scaled(vector), vector). */
  205. dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_S),
  206. sdot_ ,BLAS_S)
  207. dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_D),
  208. ddot_ ,BLAS_D)
  209. dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_C),
  210. cdotu_,BLAS_C)
  211. dot_interface(dot_p1_s, dot_trans1_s, a*, dot_p2, dot_trans2, (BLAS_Z),
  212. zdotu_,BLAS_Z)
  /* vect_sp(vector, scaled(vector)). */
  213. dot_interface(dot_p1, dot_trans1, (BLAS_S), dot_p2_s, dot_trans2_s, b*,
  214. sdot_ ,BLAS_S)
  215. dot_interface(dot_p1, dot_trans1, (BLAS_D), dot_p2_s, dot_trans2_s, b*,
  216. ddot_ ,BLAS_D)
  217. dot_interface(dot_p1, dot_trans1, (BLAS_C), dot_p2_s, dot_trans2_s, b*,
  218. cdotu_,BLAS_C)
  219. dot_interface(dot_p1, dot_trans1, (BLAS_Z), dot_p2_s, dot_trans2_s, b*,
  220. zdotu_,BLAS_Z)
  /* vect_sp(scaled(vector), scaled(vector)). */
  221. dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,sdot_ ,
  222. BLAS_S)
  223. dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,ddot_ ,
  224. BLAS_D)
  225. dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,cdotu_,
  226. BLAS_C)
  227. dot_interface(dot_p1_s,dot_trans1_s,a*,dot_p2_s,dot_trans2_s,b*,zdotu_,
  228. BLAS_Z)
  229. /* ********************************************************************* */
  230. /* vect_hp(x, y). */
  231. /* ********************************************************************* */
  /* dotc_interface generates inline vect_hp overloads with the same
     parameter scheme as the vect_sp generator: param / trans / mult triples
     for each argument, then blas_name called with (n, x, 1, y, 1), n being
     the size of y converted to int. Real types go through ?dot_, complex
     types through ?dotc_.
     For scaled arguments the factor of the first argument is applied as is,
     while the factor of the second argument is conjugated (gmm::conj).
     NOTE(review): reference BLAS ?dotc conjugates its FIRST vector argument,
     whereas the scale handling here conjugates only the SECOND factor -
     confirm the argument order against the generic vect_hp definition. */
  232. # define dotc_interface(param1, trans1, mult1, param2, trans2, mult2, \
  233. blas_name, base_type) \
  234. inline base_type vect_hp(param1(base_type), param2(base_type)) { \
  235. GMMLAPACK_TRACE("dotc_interface"); \
  236. trans1(base_type); trans2(base_type); int inc(1), n(int(vect_size(y)));\
  237. return mult1 mult2 blas_name(&n, &x[0], &inc, &y[0], &inc); \
  238. }
  /* First argument: plain vector, or scaled vector (x bound to the vector
     obtained through linalg_origin, scale factor copied into a). */
  239. # define dotc_p1(base_type) const std::vector<base_type > &x
  240. # define dotc_trans1(base_type)
  241. # define dotc_p1_s(base_type) \
  242. const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
  243. # define dotc_trans1_s(base_type) \
  244. std::vector<base_type > &x = \
  245. const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
  246. base_type a(x_.r)
  /* Second argument: same scheme, but b holds the CONJUGATE of the factor. */
  247. # define dotc_p2(base_type) const std::vector<base_type > &y
  248. # define dotc_trans2(base_type)
  249. # define dotc_p2_s(base_type) \
  250. const scaled_vector_const_ref<std::vector<base_type >, base_type > &y_
  251. # define dotc_trans2_s(base_type) \
  252. std::vector<base_type > &y = \
  253. const_cast<std::vector<base_type > &>(*(linalg_origin(y_))); \
  254. base_type b(gmm::conj(y_.r))
  /* vect_hp(vector, vector). */
  255. dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2, dotc_trans2,
  256. (BLAS_S),sdot_ ,BLAS_S)
  257. dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2, dotc_trans2,
  258. (BLAS_D),ddot_ ,BLAS_D)
  259. dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2, dotc_trans2,
  260. (BLAS_C),cdotc_,BLAS_C)
  261. dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2, dotc_trans2,
  262. (BLAS_Z),zdotc_,BLAS_Z)
  /* vect_hp(scaled(vector), vector). */
  263. dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
  264. (BLAS_S),sdot_, BLAS_S)
  265. dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
  266. (BLAS_D),ddot_ , BLAS_D)
  267. dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
  268. (BLAS_C),cdotc_, BLAS_C)
  269. dotc_interface(dotc_p1_s, dotc_trans1_s, a*, dotc_p2, dotc_trans2,
  270. (BLAS_Z),zdotc_, BLAS_Z)
  /* vect_hp(vector, scaled(vector)). */
  271. dotc_interface(dotc_p1, dotc_trans1, (BLAS_S), dotc_p2_s, dotc_trans2_s,
  272. b*,sdot_ , BLAS_S)
  273. dotc_interface(dotc_p1, dotc_trans1, (BLAS_D), dotc_p2_s, dotc_trans2_s,
  274. b*,ddot_ , BLAS_D)
  275. dotc_interface(dotc_p1, dotc_trans1, (BLAS_C), dotc_p2_s, dotc_trans2_s,
  276. b*,cdotc_, BLAS_C)
  277. dotc_interface(dotc_p1, dotc_trans1, (BLAS_Z), dotc_p2_s, dotc_trans2_s,
  278. b*,zdotc_, BLAS_Z)
  /* vect_hp(scaled(vector), scaled(vector)). */
  279. dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,sdot_ ,
  280. BLAS_S)
  281. dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,ddot_ ,
  282. BLAS_D)
  283. dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,cdotc_,
  284. BLAS_C)
  285. dotc_interface(dotc_p1_s,dotc_trans1_s,a*,dotc_p2_s,dotc_trans2_s,b*,zdotc_,
  286. BLAS_Z)
  287. /* ********************************************************************* */
  288. /* add(x, y). */
  289. /* ********************************************************************* */
  /* axpy_interface(param1, trans1, blas_name, base_type) generates an inline
     add overload whose first parameter is param1(base_type) and whose second
     is a mutable std::vector y. trans1 must leave the names x (source
     vector) and a (scale factor) defined; the body then calls
     blas_name(n, a, x, 1, y, 1) with n taken from y and converted to int. */
  290. # define axpy_interface(param1, trans1, blas_name, base_type) \
  291. inline void add(param1(base_type), std::vector<base_type > &y) { \
  292. GMMLAPACK_TRACE("axpy_interface"); \
  293. int inc(1), n(int(vect_size(y))); trans1(base_type); \
  294. blas_name(&n, &a, &x[0], &inc, &y[0], &inc); \
  295. }
  /* Plain vector: the scale factor is 1. */
  296. # define axpy_p1(base_type) const std::vector<base_type > &x
  297. # define axpy_trans1(base_type) base_type a(1)
  /* Scaled vector: x is bound to the vector obtained through linalg_origin
     and the scale factor is the reference's r member. */
  298. # define axpy_p1_s(base_type) \
  299. const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
  300. # define axpy_trans1_s(base_type) \
  301. std::vector<base_type > &x = \
  302. const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
  303. base_type a(x_.r)
  /* add(vector, vector). */
  304. axpy_interface(axpy_p1, axpy_trans1, saxpy_, BLAS_S)
  305. axpy_interface(axpy_p1, axpy_trans1, daxpy_, BLAS_D)
  306. axpy_interface(axpy_p1, axpy_trans1, caxpy_, BLAS_C)
  307. axpy_interface(axpy_p1, axpy_trans1, zaxpy_, BLAS_Z)
  /* add(scaled(vector), vector). */
  308. axpy_interface(axpy_p1_s, axpy_trans1_s, saxpy_, BLAS_S)
  309. axpy_interface(axpy_p1_s, axpy_trans1_s, daxpy_, BLAS_D)
  310. axpy_interface(axpy_p1_s, axpy_trans1_s, caxpy_, BLAS_C)
  311. axpy_interface(axpy_p1_s, axpy_trans1_s, zaxpy_, BLAS_Z)
  312. /* ********************************************************************* */
  313. /* mult_add(A, x, z). */
  314. /* ********************************************************************* */
  /* gemv_interface generates inline mult_add_spec overloads that accumulate
     a matrix-vector product into z through ?gemv with beta = 1.
       param1 / trans1 : declaration of the matrix argument and the
                         statements that must leave A (the stored dense
                         matrix) and t (the BLAS trans flag) defined.
       param2 / trans2 : declaration of the vector argument and the
                         statements that must leave x and alpha defined.
       orien           : unnamed tag parameter used for overload selection.
     m, n and lda are read from the stored matrix A.
     When A has no rows or no columns the product contributes nothing, so z
     is left untouched. (It used to be cleared in that case, which discarded
     the accumulator.) */
  315. # define gemv_interface(param1, trans1, param2, trans2, blas_name, \
  316. base_type, orien) \
  317. inline void mult_add_spec(param1(base_type), param2(base_type), \
  318. std::vector<base_type > &z, orien) { \
  319. GMMLAPACK_TRACE("gemv_interface"); \
  320. trans1(base_type); trans2(base_type); base_type beta(1); \
  321. int m(int(mat_nrows(A))), lda(m), n(int(mat_ncols(A))), inc(1); \
  322. if (m && n) blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, \
  323. &beta, &z[0], &inc); \
  324. /* empty A: nothing to add, z keeps its current value */ \
  325. }
  326. // First parameter (matrix): each gem_p1_* gives the parameter declaration,
  // the matching gem_trans1_* the statements that define A and the trans flag t.
  // Plain dense matrix: t = 'N'.
  327. # define gem_p1_n(base_type) const dense_matrix<base_type > &A
  328. # define gem_trans1_n(base_type) const char t = 'N'
  // Transposed reference to a non-const dense matrix: A is the matrix obtained
  // through linalg_origin, t = 'T'.
  329. # define gem_p1_t(base_type) \
  330. const transposed_col_ref<dense_matrix<base_type > *> &A_
  331. # define gem_trans1_t(base_type) dense_matrix<base_type > &A = \
  332. const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
  333. const char t = 'T'
  // Transposed reference to a const dense matrix; paired with gem_trans1_t
  // in the instantiations.
  334. # define gem_p1_tc(base_type) \
  335. const transposed_col_ref<const dense_matrix<base_type > *> &A_
  // Conjugated reference: A is the matrix obtained through linalg_origin,
  // t = 'C'.
  336. # define gem_p1_c(base_type) \
  337. const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_
  338. # define gem_trans1_c(base_type) dense_matrix<base_type > &A = \
  339. const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
  340. const char t = 'C'
  341. // Second parameter (vector): plain vector with alpha = 1, or scaled vector
  // with x bound to the vector obtained through linalg_origin and alpha = x_.r.
  342. # define gemv_p2_n(base_type) const std::vector<base_type > &x
  343. # define gemv_trans2_n(base_type) base_type alpha(1)
  344. # define gemv_p2_s(base_type) \
  345. const scaled_vector_const_ref<std::vector<base_type >, base_type > &x_
  346. # define gemv_trans2_s(base_type) std::vector<base_type > &x = \
  347. const_cast<std::vector<base_type > &>(*(linalg_origin(x_))); \
  348. base_type alpha(x_.r)
  // mult_add_spec overloads: one per scalar type for each combination of
  // matrix form (plain / transposed / transposed const / conjugated) and
  // vector form (plain / scaled). Plain matrices use the col_major tag, all
  // other forms the row_major tag.
  349. // Z <- AX + Z.  (trans flag 'N')
  350. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
  351. BLAS_S, col_major)
  352. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
  353. BLAS_D, col_major)
  354. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
  355. BLAS_C, col_major)
  356. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
  357. BLAS_Z, col_major)
  358. // Z <- transposed(A)X + Z.  (trans flag 'T')
  359. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
  360. BLAS_S, row_major)
  361. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
  362. BLAS_D, row_major)
  363. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
  364. BLAS_C, row_major)
  365. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
  366. BLAS_Z, row_major)
  367. // Z <- transposed(const A)X + Z.  (trans flag 'T')
  368. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
  369. BLAS_S, row_major)
  370. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
  371. BLAS_D, row_major)
  372. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
  373. BLAS_C, row_major)
  374. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
  375. BLAS_Z, row_major)
  376. // Z <- conjugated(A)X + Z.  (trans flag 'C')
  377. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
  378. BLAS_S, row_major)
  379. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
  380. BLAS_D, row_major)
  381. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
  382. BLAS_C, row_major)
  383. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
  384. BLAS_Z, row_major)
  385. // Z <- A scaled(X) + Z.  (trans flag 'N', alpha = scale factor)
  386. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
  387. BLAS_S, col_major)
  388. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
  389. BLAS_D, col_major)
  390. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
  391. BLAS_C, col_major)
  392. gemv_interface(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
  393. BLAS_Z, col_major)
  394. // Z <- transposed(A) scaled(X) + Z.  (trans flag 'T', alpha = scale factor)
  395. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
  396. BLAS_S, row_major)
  397. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
  398. BLAS_D, row_major)
  399. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
  400. BLAS_C, row_major)
  401. gemv_interface(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
  402. BLAS_Z, row_major)
  403. // Z <- transposed(const A) scaled(X) + Z.  (trans flag 'T', alpha = scale factor)
  404. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
  405. BLAS_S, row_major)
  406. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
  407. BLAS_D, row_major)
  408. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
  409. BLAS_C, row_major)
  410. gemv_interface(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
  411. BLAS_Z, row_major)
  412. // Z <- conjugated(A) scaled(X) + Z.  (trans flag 'C', alpha = scale factor)
  413. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
  414. BLAS_S, row_major)
  415. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
  416. BLAS_D, row_major)
  417. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
  418. BLAS_C, row_major)
  419. gemv_interface(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
  420. BLAS_Z, row_major)
  421. /* ********************************************************************* */
  422. /* mult(A, x, y). */
  423. /* ********************************************************************* */
  /* gemv_interface2 generates inline mult_spec overloads that overwrite z
     with a matrix-vector product through ?gemv with beta = 0.
       param1 / trans1 : matrix parameter declaration and the statements
                         that must leave A and the trans flag t defined.
       param2 / trans2 : vector parameter declaration and the statements
                         that must leave x and alpha defined.
       orien           : unnamed tag parameter used for overload selection.
     m, n and lda are read from the stored matrix A. When A has no rows or
     no columns the BLAS call is skipped and z is cleared instead. */
  424. # define gemv_interface2(param1, trans1, param2, trans2, blas_name, \
  425. base_type, orien) \
  426. inline void mult_spec(param1(base_type), param2(base_type), \
  427. std::vector<base_type > &z, orien) { \
  428. GMMLAPACK_TRACE("gemv_interface2"); \
  429. trans1(base_type); trans2(base_type); base_type beta(0); \
  430. int m(int(mat_nrows(A))), lda(m), n(int(mat_ncols(A))), inc(1); \
  431. if (m && n) \
  432. blas_name(&t, &m, &n, &alpha, &A(0,0), &lda, &x[0], &inc, &beta, \
  433. &z[0], &inc); \
  434. else gmm::clear(z); \
  435. }
  // mult_spec overloads: one per scalar type for each combination of matrix
  // form (plain / transposed / transposed const / conjugated) and vector form
  // (plain / scaled). Plain matrices use the col_major tag, all other forms
  // the row_major tag.
  436. // Y <- AX.  (trans flag 'N')
  437. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, sgemv_,
  438. BLAS_S, col_major)
  439. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, dgemv_,
  440. BLAS_D, col_major)
  441. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, cgemv_,
  442. BLAS_C, col_major)
  443. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_n, gemv_trans2_n, zgemv_,
  444. BLAS_Z, col_major)
  445. // Y <- transposed(A)X.  (trans flag 'T')
  446. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
  447. BLAS_S, row_major)
  448. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
  449. BLAS_D, row_major)
  450. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
  451. BLAS_C, row_major)
  452. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
  453. BLAS_Z, row_major)
  454. // Y <- transposed(const A)X.  (trans flag 'T')
  455. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, sgemv_,
  456. BLAS_S, row_major)
  457. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, dgemv_,
  458. BLAS_D, row_major)
  459. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, cgemv_,
  460. BLAS_C, row_major)
  461. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_n, gemv_trans2_n, zgemv_,
  462. BLAS_Z, row_major)
  463. // Y <- conjugated(A)X.  (trans flag 'C')
  464. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, sgemv_,
  465. BLAS_S, row_major)
  466. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, dgemv_,
  467. BLAS_D, row_major)
  468. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, cgemv_,
  469. BLAS_C, row_major)
  470. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_n, gemv_trans2_n, zgemv_,
  471. BLAS_Z, row_major)
  472. // Y <- A scaled(X).  (trans flag 'N', alpha = scale factor)
  473. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, sgemv_,
  474. BLAS_S, col_major)
  475. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, dgemv_,
  476. BLAS_D, col_major)
  477. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, cgemv_,
  478. BLAS_C, col_major)
  479. gemv_interface2(gem_p1_n, gem_trans1_n, gemv_p2_s, gemv_trans2_s, zgemv_,
  480. BLAS_Z, col_major)
  481. // Y <- transposed(A) scaled(X).  (trans flag 'T', alpha = scale factor)
  482. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
  483. BLAS_S, row_major)
  484. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
  485. BLAS_D, row_major)
  486. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
  487. BLAS_C, row_major)
  488. gemv_interface2(gem_p1_t, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
  489. BLAS_Z, row_major)
  490. // Y <- transposed(const A) scaled(X).  (trans flag 'T', alpha = scale factor)
  491. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, sgemv_,
  492. BLAS_S, row_major)
  493. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, dgemv_,
  494. BLAS_D, row_major)
  495. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, cgemv_,
  496. BLAS_C, row_major)
  497. gemv_interface2(gem_p1_tc, gem_trans1_t, gemv_p2_s, gemv_trans2_s, zgemv_,
  498. BLAS_Z, row_major)
  499. // Y <- conjugated(A) scaled(X).  (trans flag 'C', alpha = scale factor)
  500. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, sgemv_,
  501. BLAS_S, row_major)
  502. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, dgemv_,
  503. BLAS_D, row_major)
  504. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, cgemv_,
  505. BLAS_C, row_major)
  506. gemv_interface2(gem_p1_c, gem_trans1_c, gemv_p2_s, gemv_trans2_s, zgemv_,
  507. BLAS_Z, row_major)
  508. /* ********************************************************************* */
  509. /* Rank one update. */
  510. /* ********************************************************************* */
  /* ger_interface(blas_name, base_type) generates
     rank_one_update(A, V, W) for plain vectors: blas_name is called with
     alpha = 1, unit strides for V and W, and lda equal to the number of rows
     of A. The call is skipped when A has no rows or no columns. Real types
     use ?ger_, complex types ?gerc_.
     NOTE(review): A is received by const reference although its storage is
     what the BLAS routine updates - presumably this mirrors the signature of
     the generic rank_one_update; confirm. */
  511. # define ger_interface(blas_name, base_type) \
  512. inline void rank_one_update(const dense_matrix<base_type > &A, \
  513. const std::vector<base_type > &V, \
  514. const std::vector<base_type > &W) { \
  515. GMMLAPACK_TRACE("ger_interface"); \
  516. int m(int(mat_nrows(A))), lda = m, n(int(mat_ncols(A))); \
  517. int incx = 1, incy = 1; \
  518. base_type alpha(1); \
  519. if (m && n) \
  520. blas_name(&m, &n, &alpha, &V[0], &incx, &W[0], &incy, &A(0,0), &lda);\
  521. }
  522. ger_interface(sger_, BLAS_S)
  523. ger_interface(dger_, BLAS_D)
  524. ger_interface(cgerc_, BLAS_C)
  525. ger_interface(zgerc_, BLAS_Z)
  /* ger_interface_sn(blas_name, base_type) generates
     rank_one_update(A, scaled(V), W): the first vector argument is a scaled
     reference. gemv_trans2_s binds x to the underlying vector and copies the
     scale factor into alpha, which is passed to blas_name as its scalar.
     Unit strides are used and lda is the number of rows of A; the call is
     skipped when A has no rows or no columns.
     The trace label names this variant so that a trace can tell it apart
     from the unscaled overload. */
  526. # define ger_interface_sn(blas_name, base_type) \
  527. inline void rank_one_update(const dense_matrix<base_type > &A, \
  528. gemv_p2_s(base_type), \
  529. const std::vector<base_type > &W) { \
  530. GMMLAPACK_TRACE("ger_interface_sn"); \
  531. gemv_trans2_s(base_type); \
  532. int m(int(mat_nrows(A))), lda = m, n(int(mat_ncols(A))); \
  533. int incx = 1, incy = 1; \
  534. if (m && n) \
  535. blas_name(&m, &n, &alpha, &x[0], &incx, &W[0], &incy, &A(0,0), &lda);\
  536. }
  537. ger_interface_sn(sger_, BLAS_S)
  538. ger_interface_sn(dger_, BLAS_D)
  539. ger_interface_sn(cgerc_, BLAS_C)
  540. ger_interface_sn(zgerc_, BLAS_Z)
  /* ger_interface_ns(blas_name, base_type) generates
     rank_one_update(A, V, scaled(W)): the second vector argument is a scaled
     reference. gemv_trans2_s binds x to the underlying vector and copies the
     scale factor into alpha; the scalar handed to blas_name is the conjugate
     of that factor (al2). Unit strides are used and lda is the number of
     rows of A; the call is skipped when A has no rows or no columns.
     The trace label names this variant so that a trace can tell it apart
     from the unscaled overload. */
  541. # define ger_interface_ns(blas_name, base_type) \
  542. inline void rank_one_update(const dense_matrix<base_type > &A, \
  543. const std::vector<base_type > &V, \
  544. gemv_p2_s(base_type)) { \
  545. GMMLAPACK_TRACE("ger_interface_ns"); \
  546. gemv_trans2_s(base_type); \
  547. int m(int(mat_nrows(A))), lda = m, n(int(mat_ncols(A))); \
  548. int incx = 1, incy = 1; \
  549. base_type al2 = gmm::conj(alpha); \
  550. if (m && n) \
  551. blas_name(&m, &n, &al2, &V[0], &incx, &x[0], &incy, &A(0,0), &lda); \
  552. }
  553. ger_interface_ns(sger_, BLAS_S)
  554. ger_interface_ns(dger_, BLAS_D)
  555. ger_interface_ns(cgerc_, BLAS_C)
  556. ger_interface_ns(zgerc_, BLAS_Z)
  557. /* ********************************************************************* */
  558. /* dense matrix x dense matrix multiplication. */
  559. /* ********************************************************************* */
  /* gemm_interface_nn(blas_name, base_type) generates
     mult_spec(A, B, C, c_mult): C is overwritten with the product of A and B
     through ?gemm with both trans flags 'N', alpha = 1 and beta = 0.
     Dimensions: m = rows of A, k = columns of A, n = columns of B; leading
     dimensions are m for A and C and k for B. If any of m, k, n is zero the
     BLAS call is skipped and C is cleared. */
  560. # define gemm_interface_nn(blas_name, base_type) \
  561. inline void mult_spec(const dense_matrix<base_type > &A, \
  562. const dense_matrix<base_type > &B, \
  563. dense_matrix<base_type > &C, c_mult) { \
  564. GMMLAPACK_TRACE("gemm_interface_nn"); \
  565. const char t = 'N'; \
  566. int m(int(mat_nrows(A))), lda = m, k(int(mat_ncols(A))); \
  567. int n(int(mat_ncols(B))); \
  568. int ldb = k, ldc = m; \
  569. base_type alpha(1), beta(0); \
  570. if (m && k && n) \
  571. blas_name(&t, &t, &m, &n, &k, &alpha, \
  572. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  573. else gmm::clear(C); \
  574. }
  575. gemm_interface_nn(sgemm_, BLAS_S)
  576. gemm_interface_nn(dgemm_, BLAS_D)
  577. gemm_interface_nn(cgemm_, BLAS_C)
  578. gemm_interface_nn(zgemm_, BLAS_Z)
  579. /* ********************************************************************* */
  580. /* transposed(dense matrix) x dense matrix multiplication. */
  581. /* ********************************************************************* */
  /* gemm_interface_tn(blas_name, base_type, is_const) generates
     mult_spec(transposed(A), B, C, rcmult). is_const is the template name
     placed inside the transposed reference type (dense_matrix or
     const dense_matrix), so both forms get an overload.
     A is the stored matrix obtained through linalg_origin; ?gemm is called
     with trans flags 'T' for A and 'N' for B, alpha = 1 and beta = 0.
     Dimensions: m = columns of stored A, k = rows of stored A,
     n = columns of B; leading dimensions are k for A and B and m for C.
     If any of m, k, n is zero the BLAS call is skipped and C is cleared. */
  582. # define gemm_interface_tn(blas_name, base_type, is_const) \
  583. inline void mult_spec( \
  584. const transposed_col_ref<is_const<base_type > *> &A_,\
  585. const dense_matrix<base_type > &B, \
  586. dense_matrix<base_type > &C, rcmult) { \
  587. GMMLAPACK_TRACE("gemm_interface_tn"); \
  588. dense_matrix<base_type > &A \
  589. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
  590. const char t = 'T', u = 'N'; \
  591. int m(int(mat_ncols(A))), k(int(mat_nrows(A))), n(int(mat_ncols(B))); \
  592. int lda = k, ldb = k, ldc = m; \
  593. base_type alpha(1), beta(0); \
  594. if (m && k && n) \
  595. blas_name(&t, &u, &m, &n, &k, &alpha, \
  596. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  597. else gmm::clear(C); \
  598. }
  599. gemm_interface_tn(sgemm_, BLAS_S, dense_matrix)
  600. gemm_interface_tn(dgemm_, BLAS_D, dense_matrix)
  601. gemm_interface_tn(cgemm_, BLAS_C, dense_matrix)
  602. gemm_interface_tn(zgemm_, BLAS_Z, dense_matrix)
  603. gemm_interface_tn(sgemm_, BLAS_S, const dense_matrix)
  604. gemm_interface_tn(dgemm_, BLAS_D, const dense_matrix)
  605. gemm_interface_tn(cgemm_, BLAS_C, const dense_matrix)
  606. gemm_interface_tn(zgemm_, BLAS_Z, const dense_matrix)
  607. /* ********************************************************************* */
  608. /* dense matrix x transposed(dense matrix) multiplication. */
  609. /* ********************************************************************* */
  /* gemm_interface_nt(blas_name, base_type, is_const) generates
     mult_spec(A, transposed(B), C, r_mult). is_const is the template name
     placed inside the transposed reference type (dense_matrix or
     const dense_matrix), so both forms get an overload.
     B is the stored matrix obtained through linalg_origin; ?gemm is called
     with trans flags 'N' for A and 'T' for B, alpha = 1 and beta = 0.
     Dimensions: m = rows of A, k = columns of A, n = rows of stored B;
     leading dimensions are m for A and C and n for B.
     If any of m, k, n is zero the BLAS call is skipped and C is cleared. */
  610. # define gemm_interface_nt(blas_name, base_type, is_const) \
  611. inline void mult_spec(const dense_matrix<base_type > &A, \
  612. const transposed_col_ref<is_const<base_type > *> &B_,\
  613. dense_matrix<base_type > &C, r_mult) { \
  614. GMMLAPACK_TRACE("gemm_interface_nt"); \
  615. dense_matrix<base_type > &B \
  616. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
  617. const char t = 'N', u = 'T'; \
  618. int m(int(mat_nrows(A))), lda = m, k(int(mat_ncols(A))); \
  619. int n(int(mat_nrows(B))); \
  620. int ldb = n, ldc = m; \
  621. base_type alpha(1), beta(0); \
  622. if (m && k && n) \
  623. blas_name(&t, &u, &m, &n, &k, &alpha, \
  624. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  625. else gmm::clear(C); \
  626. }
  627. gemm_interface_nt(sgemm_, BLAS_S, dense_matrix)
  628. gemm_interface_nt(dgemm_, BLAS_D, dense_matrix)
  629. gemm_interface_nt(cgemm_, BLAS_C, dense_matrix)
  630. gemm_interface_nt(zgemm_, BLAS_Z, dense_matrix)
  631. gemm_interface_nt(sgemm_, BLAS_S, const dense_matrix)
  632. gemm_interface_nt(dgemm_, BLAS_D, const dense_matrix)
  633. gemm_interface_nt(cgemm_, BLAS_C, const dense_matrix)
  634. gemm_interface_nt(zgemm_, BLAS_Z, const dense_matrix)
  635. /* ********************************************************************* */
  636. /* transposed(dense matrix) x transposed(dense matrix) multiplication. */
  637. /* ********************************************************************* */
  638. # define gemm_interface_tt(blas_name, base_type, isA_const, isB_const) \
  639. inline void mult_spec( \
  640. const transposed_col_ref<isA_const <base_type > *> &A_,\
  641. const transposed_col_ref<isB_const <base_type > *> &B_,\
  642. dense_matrix<base_type > &C, r_mult) { \
  643. GMMLAPACK_TRACE("gemm_interface_tt"); \
  644. dense_matrix<base_type > &A \
  645. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
  646. dense_matrix<base_type > &B \
  647. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
  648. const char t = 'T', u = 'T'; \
  649. int m(int(mat_ncols(A))), k(int(mat_nrows(A))), n(int(mat_nrows(B))); \
  650. int lda = k, ldb = n, ldc = m; \
  651. base_type alpha(1), beta(0); \
  652. if (m && k && n) \
  653. blas_name(&t, &u, &m, &n, &k, &alpha, \
  654. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  655. else gmm::clear(C); \
  656. }
  657. gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, dense_matrix)
  658. gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, dense_matrix)
  659. gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, dense_matrix)
  660. gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, dense_matrix)
  661. gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, dense_matrix)
  662. gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, dense_matrix)
  663. gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, dense_matrix)
  664. gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, dense_matrix)
  665. gemm_interface_tt(sgemm_, BLAS_S, dense_matrix, const dense_matrix)
  666. gemm_interface_tt(dgemm_, BLAS_D, dense_matrix, const dense_matrix)
  667. gemm_interface_tt(cgemm_, BLAS_C, dense_matrix, const dense_matrix)
  668. gemm_interface_tt(zgemm_, BLAS_Z, dense_matrix, const dense_matrix)
  669. gemm_interface_tt(sgemm_, BLAS_S, const dense_matrix, const dense_matrix)
  670. gemm_interface_tt(dgemm_, BLAS_D, const dense_matrix, const dense_matrix)
  671. gemm_interface_tt(cgemm_, BLAS_C, const dense_matrix, const dense_matrix)
  672. gemm_interface_tt(zgemm_, BLAS_Z, const dense_matrix, const dense_matrix)
  673. /* ********************************************************************* */
  674. /* conjugated(dense matrix) x dense matrix multiplication. */
  675. /* ********************************************************************* */
  676. # define gemm_interface_cn(blas_name, base_type) \
  677. inline void mult_spec( \
  678. const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
  679. const dense_matrix<base_type > &B, \
  680. dense_matrix<base_type > &C, rcmult) { \
  681. GMMLAPACK_TRACE("gemm_interface_cn"); \
  682. dense_matrix<base_type > &A \
  683. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
  684. const char t = 'C', u = 'N'; \
  685. int m(int(mat_ncols(A))), k(int(mat_nrows(A))), n(int(mat_ncols(B))); \
  686. int lda = k, ldb = k, ldc = m; \
  687. base_type alpha(1), beta(0); \
  688. if (m && k && n) \
  689. blas_name(&t, &u, &m, &n, &k, &alpha, \
  690. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  691. else gmm::clear(C); \
  692. }
  693. gemm_interface_cn(sgemm_, BLAS_S)
  694. gemm_interface_cn(dgemm_, BLAS_D)
  695. gemm_interface_cn(cgemm_, BLAS_C)
  696. gemm_interface_cn(zgemm_, BLAS_Z)
  697. /* ********************************************************************* */
  698. /* dense matrix x conjugated(dense matrix) multiplication. */
  699. /* ********************************************************************* */
  700. # define gemm_interface_nc(blas_name, base_type) \
  701. inline void mult_spec(const dense_matrix<base_type > &A, \
  702. const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
  703. dense_matrix<base_type > &C, c_mult, row_major) { \
  704. GMMLAPACK_TRACE("gemm_interface_nc"); \
  705. dense_matrix<base_type > &B \
  706. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
  707. const char t = 'N', u = 'C'; \
  708. int m(int(mat_nrows(A))), lda = m, k(int(mat_ncols(A))); \
  709. int n(int(mat_nrows(B))), ldb = n, ldc = m; \
  710. base_type alpha(1), beta(0); \
  711. if (m && k && n) \
  712. blas_name(&t, &u, &m, &n, &k, &alpha, \
  713. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  714. else gmm::clear(C); \
  715. }
  716. gemm_interface_nc(sgemm_, BLAS_S)
  717. gemm_interface_nc(dgemm_, BLAS_D)
  718. gemm_interface_nc(cgemm_, BLAS_C)
  719. gemm_interface_nc(zgemm_, BLAS_Z)
  720. /* ********************************************************************* */
  721. /* conjugated(dense matrix) x conjugated(dense matrix) multiplication. */
  722. /* ********************************************************************* */
  723. # define gemm_interface_cc(blas_name, base_type) \
  724. inline void mult_spec( \
  725. const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &A_,\
  726. const conjugated_col_matrix_const_ref<dense_matrix<base_type > > &B_,\
  727. dense_matrix<base_type > &C, r_mult) { \
  728. GMMLAPACK_TRACE("gemm_interface_cc"); \
  729. dense_matrix<base_type > &A \
  730. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(A_))); \
  731. dense_matrix<base_type > &B \
  732. = const_cast<dense_matrix<base_type > &>(*(linalg_origin(B_))); \
  733. const char t = 'C', u = 'C'; \
  734. int m(int(mat_ncols(A))), k(int(mat_nrows(A))), lda = k; \
  735. int n(int(mat_nrows(B))), ldb = n, ldc = m; \
  736. base_type alpha(1), beta(0); \
  737. if (m && k && n) \
  738. blas_name(&t, &u, &m, &n, &k, &alpha, \
  739. &A(0,0), &lda, &B(0,0), &ldb, &beta, &C(0,0), &ldc); \
  740. else gmm::clear(C); \
  741. }
  742. gemm_interface_cc(sgemm_, BLAS_S)
  743. gemm_interface_cc(dgemm_, BLAS_D)
  744. gemm_interface_cc(cgemm_, BLAS_C)
  745. gemm_interface_cc(zgemm_, BLAS_Z)
  746. /* ********************************************************************* */
  747. /* Tri solve. */
  748. /* ********************************************************************* */
  // Generates a dense triangular solve that overwrites x, delegated to a
  // BLAS trsv routine.
  //   f_name    : name of the generated function (lower_tri_solve or
  //               upper_tri_solve).
  //   loru      : statement declaring `l`, the BLAS uplo flag; use
  //               trsv_upper or trsv_lower (defined just below).
  //   param1    : macro expanding to the declaration of the matrix
  //               parameter.
  //   trans1    : macro expanding to the statement(s) that must leave the
  //               matrix `A` and the BLAS transposition flag `t` in scope.
  //               param1/trans1 are defined elsewhere in this file.
  //   blas_name : BLAS trsv entry point (strsv_, dtrsv_, ctrsv_, ztrsv_).
  //   base_type : scalar type of the matrix and of x.
  // In the generated function, k is forwarded to BLAS as the order of the
  // system, the row count of A as its leading dimension, and is_unit selects
  // the diag flag ('U' when true, 'N' otherwise).
  // BLAS is called only when both the matrix and the requested order are
  // non-empty: with k == 0 there is nothing to solve, and skipping the call
  // avoids evaluating &x[0] on a vector that may be empty.
  # define trsv_interface(f_name, loru, param1, trans1, blas_name, base_type) \
  inline void f_name(param1(base_type), std::vector<base_type > &x, \
                     size_type k, bool is_unit) { \
    GMMLAPACK_TRACE("trsv_interface"); \
    loru; trans1(base_type); char d = is_unit ? 'U' : 'N'; \
    int lda(int(mat_nrows(A))), inc(1), n = int(k); \
    if (lda && n) blas_name(&l, &t, &d, &n, &A(0,0), &lda, &x[0], &inc); \
  }
  // BLAS uplo flag: the upper triangle of the stored matrix is used.
  # define trsv_upper const char l = 'U'
  // BLAS uplo flag: the lower triangle of the stored matrix is used.
  # define trsv_lower const char l = 'L'
  759. // X <- LOWER(A)^{-1}X.
  760. trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
  761. strsv_, BLAS_S)
  762. trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
  763. dtrsv_, BLAS_D)
  764. trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
  765. ctrsv_, BLAS_C)
  766. trsv_interface(lower_tri_solve, trsv_lower, gem_p1_n, gem_trans1_n,
  767. ztrsv_, BLAS_Z)
  768. // X <- UPPER(A)^{-1}X.
  769. trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
  770. strsv_, BLAS_S)
  771. trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
  772. dtrsv_, BLAS_D)
  773. trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
  774. ctrsv_, BLAS_C)
  775. trsv_interface(upper_tri_solve, trsv_upper, gem_p1_n, gem_trans1_n,
  776. ztrsv_, BLAS_Z)
  777. // X <- LOWER(transposed(A))^{-1}X.
  778. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
  779. strsv_, BLAS_S)
  780. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
  781. dtrsv_, BLAS_D)
  782. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
  783. ctrsv_, BLAS_C)
  784. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_t, gem_trans1_t,
  785. ztrsv_, BLAS_Z)
  786. // X <- UPPER(transposed(A))^{-1}X.
  787. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
  788. strsv_, BLAS_S)
  789. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
  790. dtrsv_, BLAS_D)
  791. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
  792. ctrsv_, BLAS_C)
  793. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_t, gem_trans1_t,
  794. ztrsv_, BLAS_Z)
  795. // X <- LOWER(transposed(const A))^{-1}X.
  796. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
  797. strsv_, BLAS_S)
  798. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
  799. dtrsv_, BLAS_D)
  800. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
  801. ctrsv_, BLAS_C)
  802. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_tc, gem_trans1_t,
  803. ztrsv_, BLAS_Z)
  804. // X <- UPPER(transposed(const A))^{-1}X.
  805. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
  806. strsv_, BLAS_S)
  807. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
  808. dtrsv_, BLAS_D)
  809. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
  810. ctrsv_, BLAS_C)
  811. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_tc, gem_trans1_t,
  812. ztrsv_, BLAS_Z)
  813. // X <- LOWER(conjugated(A))^{-1}X.
  814. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
  815. strsv_, BLAS_S)
  816. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
  817. dtrsv_, BLAS_D)
  818. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
  819. ctrsv_, BLAS_C)
  820. trsv_interface(lower_tri_solve, trsv_upper, gem_p1_c, gem_trans1_c,
  821. ztrsv_, BLAS_Z)
  822. // X <- UPPER(conjugated(A))^{-1}X.
  823. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
  824. strsv_, BLAS_S)
  825. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
  826. dtrsv_, BLAS_D)
  827. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
  828. ctrsv_, BLAS_C)
  829. trsv_interface(upper_tri_solve, trsv_lower, gem_p1_c, gem_trans1_c,
  830. ztrsv_, BLAS_Z)
  831. }
  832. #endif // GMM_BLAS_INTERFACE_H
  833. #endif // GMM_USES_BLAS