You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

410 lines
16 KiB

  1. /* -*- c++ -*- (enables emacs c++ mode) */
  2. /*===========================================================================
  3. Copyright (C) 2003-2015 Yves Renard
  4. This file is a part of GETFEM++
  5. Getfem++ is free software; you can redistribute it and/or modify it
  6. under the terms of the GNU Lesser General Public License as published
  7. by the Free Software Foundation; either version 3 of the License, or
  8. (at your option) any later version along with the GCC Runtime Library
  9. Exception either version 3.1 or (at your option) any later version.
  10. This program is distributed in the hope that it will be useful, but
  11. WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
  12. or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
  13. License and GCC Runtime Library Exception for more details.
  14. You should have received a copy of the GNU Lesser General Public License
  15. along with this program; if not, write to the Free Software Foundation,
  16. Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA.
  17. As a special exception, you may use this file as it is a part of a free
  18. software library without restriction. Specifically, if other files
  19. instantiate templates or use macros or inline functions from this file,
  20. or you compile this file and link it with other files to produce an
  21. executable, this file does not by itself cause the resulting executable
  22. to be covered by the GNU Lesser General Public License. This exception
  23. does not however invalidate any other reasons why the executable file
  24. might be covered by the GNU Lesser General Public License.
  25. ===========================================================================*/
  26. /**@file gmm_superlu_interface.h
  27. @author Yves Renard <Yves.Renard@insa-lyon.fr>
  28. @date October 17, 2003.
  29. @brief Interface with SuperLU (LU direct solver for sparse matrices).
  30. */
  31. #if defined(GMM_USES_SUPERLU) && !defined(GETFEM_VERSION)
  32. #ifndef GMM_SUPERLU_INTERFACE_H
  33. #define GMM_SUPERLU_INTERFACE_H
  34. #include "gmm_kernel.h"
  35. typedef int int_t;
  36. /* because SRC/util.h defines TRUE and FALSE ... */
  37. #ifdef TRUE
  38. # undef TRUE
  39. #endif
  40. #ifdef FALSE
  41. # undef FALSE
  42. #endif
  43. #include "superlu/slu_Cnames.h"
  44. #include "superlu/supermatrix.h"
  45. #include "superlu/slu_util.h"
  46. namespace SuperLU_S {
  47. #include "superlu/slu_sdefs.h"
  48. }
  49. namespace SuperLU_D {
  50. #include "superlu/slu_ddefs.h"
  51. }
  52. namespace SuperLU_C {
  53. #include "superlu/slu_cdefs.h"
  54. }
  55. namespace SuperLU_Z {
  56. #include "superlu/slu_zdefs.h"
  57. }
  58. namespace gmm {
  59. /* interface for Create_CompCol_Matrix */
  60. inline void Create_CompCol_Matrix(SuperMatrix *A, int m, int n, int nnz,
  61. float *a, int *ir, int *jc) {
  62. SuperLU_S::sCreate_CompCol_Matrix(A, m, n, nnz, a, ir, jc,
  63. SLU_NC, SLU_S, SLU_GE);
  64. }
  65. inline void Create_CompCol_Matrix(SuperMatrix *A, int m, int n, int nnz,
  66. double *a, int *ir, int *jc) {
  67. SuperLU_D::dCreate_CompCol_Matrix(A, m, n, nnz, a, ir, jc,
  68. SLU_NC, SLU_D, SLU_GE);
  69. }
  70. inline void Create_CompCol_Matrix(SuperMatrix *A, int m, int n, int nnz,
  71. std::complex<float> *a, int *ir, int *jc) {
  72. SuperLU_C::cCreate_CompCol_Matrix(A, m, n, nnz, (SuperLU_C::complex *)(a),
  73. ir, jc, SLU_NC, SLU_C, SLU_GE);
  74. }
  75. inline void Create_CompCol_Matrix(SuperMatrix *A, int m, int n, int nnz,
  76. std::complex<double> *a, int *ir, int *jc) {
  77. SuperLU_Z::zCreate_CompCol_Matrix(A, m, n, nnz,
  78. (SuperLU_Z::doublecomplex *)(a), ir, jc,
  79. SLU_NC, SLU_Z, SLU_GE);
  80. }
  81. /* interface for Create_Dense_Matrix */
  82. inline void Create_Dense_Matrix(SuperMatrix *A, int m, int n, float *a, int k)
  83. { SuperLU_S::sCreate_Dense_Matrix(A, m, n, a, k, SLU_DN, SLU_S, SLU_GE); }
  84. inline void Create_Dense_Matrix(SuperMatrix *A, int m, int n, double *a, int k)
  85. { SuperLU_D::dCreate_Dense_Matrix(A, m, n, a, k, SLU_DN, SLU_D, SLU_GE); }
  86. inline void Create_Dense_Matrix(SuperMatrix *A, int m, int n,
  87. std::complex<float> *a, int k) {
  88. SuperLU_C::cCreate_Dense_Matrix(A, m, n, (SuperLU_C::complex *)(a),
  89. k, SLU_DN, SLU_C, SLU_GE);
  90. }
  91. inline void Create_Dense_Matrix(SuperMatrix *A, int m, int n,
  92. std::complex<double> *a, int k) {
  93. SuperLU_Z::zCreate_Dense_Matrix(A, m, n, (SuperLU_Z::doublecomplex *)(a),
  94. k, SLU_DN, SLU_Z, SLU_GE);
  95. }
  96. /* interface for gssv */
  97. #define DECL_GSSV(NAMESPACE,FNAME,FLOATTYPE,KEYTYPE) \
  98. inline void SuperLU_gssv(superlu_options_t *options, SuperMatrix *A, int *p, \
  99. int *q, SuperMatrix *L, SuperMatrix *U, SuperMatrix *B, \
  100. SuperLUStat_t *stats, int *info, KEYTYPE) { \
  101. NAMESPACE::FNAME(options, A, p, q, L, U, B, stats, info); \
  102. }
  103. DECL_GSSV(SuperLU_S,sgssv,float,float)
  104. DECL_GSSV(SuperLU_C,cgssv,float,std::complex<float>)
  105. DECL_GSSV(SuperLU_D,dgssv,double,double)
  106. DECL_GSSV(SuperLU_Z,zgssv,double,std::complex<double>)
  107. /* interface for gssvx */
  108. #define DECL_GSSVX(NAMESPACE,FNAME,FLOATTYPE,KEYTYPE) \
  109. inline float SuperLU_gssvx(superlu_options_t *options, SuperMatrix *A, \
  110. int *perm_c, int *perm_r, int *etree, char *equed, \
  111. FLOATTYPE *R, FLOATTYPE *C, SuperMatrix *L, \
  112. SuperMatrix *U, void *work, int lwork, \
  113. SuperMatrix *B, SuperMatrix *X, \
  114. FLOATTYPE *recip_pivot_growth, \
  115. FLOATTYPE *rcond, FLOATTYPE *ferr, FLOATTYPE *berr, \
  116. SuperLUStat_t *stats, int *info, KEYTYPE) { \
  117. NAMESPACE::mem_usage_t mem_usage; \
  118. NAMESPACE::FNAME(options, A, perm_c, perm_r, etree, equed, R, C, L, \
  119. U, work, lwork, B, X, recip_pivot_growth, rcond, \
  120. ferr, berr, &mem_usage, stats, info); \
  121. return mem_usage.for_lu; /* bytes used by the factor storage */ \
  122. }
  123. DECL_GSSVX(SuperLU_S,sgssvx,float,float)
  124. DECL_GSSVX(SuperLU_C,cgssvx,float,std::complex<float>)
  125. DECL_GSSVX(SuperLU_D,dgssvx,double,double)
  126. DECL_GSSVX(SuperLU_Z,zgssvx,double,std::complex<double>)
  127. /* ********************************************************************* */
  128. /* SuperLU solve interface */
  129. /* ********************************************************************* */
  130. template <typename MAT, typename VECTX, typename VECTB>
  131. int SuperLU_solve(const MAT &A, const VECTX &X_, const VECTB &B,
  132. double& rcond_, int permc_spec = 3) {
  133. VECTX &X = const_cast<VECTX &>(X_);
  134. /*
  135. * Get column permutation vector perm_c[], according to permc_spec:
  136. * permc_spec = 0: use the natural ordering
  137. * permc_spec = 1: use minimum degree ordering on structure of A'*A
  138. * permc_spec = 2: use minimum degree ordering on structure of A'+A
  139. * permc_spec = 3: use approximate minimum degree column ordering
  140. */
  141. typedef typename linalg_traits<MAT>::value_type T;
  142. typedef typename number_traits<T>::magnitude_type R;
  143. int m = mat_nrows(A), n = mat_ncols(A), nrhs = 1, info = 0;
  144. csc_matrix<T> csc_A(m, n); gmm::copy(A, csc_A);
  145. std::vector<T> rhs(m), sol(m);
  146. gmm::copy(B, rhs);
  147. int nz = nnz(csc_A);
  148. if ((2 * nz / n) >= m)
  149. GMM_WARNING2("CAUTION : it seems that SuperLU has a problem"
  150. " for nearly dense sparse matrices");
  151. superlu_options_t options;
  152. set_default_options(&options);
  153. options.ColPerm = NATURAL;
  154. options.PrintStat = NO;
  155. options.ConditionNumber = YES;
  156. switch (permc_spec) {
  157. case 1 : options.ColPerm = MMD_ATA; break;
  158. case 2 : options.ColPerm = MMD_AT_PLUS_A; break;
  159. case 3 : options.ColPerm = COLAMD; break;
  160. }
  161. SuperLUStat_t stat;
  162. StatInit(&stat);
  163. SuperMatrix SA, SL, SU, SB, SX; // SuperLU format.
  164. Create_CompCol_Matrix(&SA, m, n, nz, (double *)(&(csc_A.pr[0])),
  165. (int *)(&(csc_A.ir[0])), (int *)(&(csc_A.jc[0])));
  166. Create_Dense_Matrix(&SB, m, nrhs, &rhs[0], m);
  167. Create_Dense_Matrix(&SX, m, nrhs, &sol[0], m);
  168. memset(&SL,0,sizeof SL);
  169. memset(&SU,0,sizeof SU);
  170. std::vector<int> etree(n);
  171. char equed[] = "B";
  172. std::vector<R> Rscale(m),Cscale(n); // row scale factors
  173. std::vector<R> ferr(nrhs), berr(nrhs);
  174. R recip_pivot_gross, rcond;
  175. std::vector<int> perm_r(m), perm_c(n);
  176. SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
  177. &etree[0] /* output */, equed /* output */,
  178. &Rscale[0] /* row scale factors (output) */,
  179. &Cscale[0] /* col scale factors (output) */,
  180. &SL /* fact L (output)*/, &SU /* fact U (output)*/,
  181. NULL /* work */,
  182. 0 /* lwork: superlu auto allocates (input) */,
  183. &SB /* rhs */, &SX /* solution */,
  184. &recip_pivot_gross /* reciprocal pivot growth */
  185. /* factor max_j( norm(A_j)/norm(U_j) ). */,
  186. &rcond /*estimate of the reciprocal condition */
  187. /* number of the matrix A after equilibration */,
  188. &ferr[0] /* estimated forward error */,
  189. &berr[0] /* relative backward error */,
  190. &stat, &info, T());
  191. rcond_ = rcond;
  192. Destroy_SuperMatrix_Store(&SB);
  193. Destroy_SuperMatrix_Store(&SX);
  194. Destroy_SuperMatrix_Store(&SA);
  195. Destroy_SuperNode_Matrix(&SL);
  196. Destroy_CompCol_Matrix(&SU);
  197. StatFree(&stat);
  198. GMM_ASSERT1(info >= 0, "SuperLU solve failed: info =" << info);
  199. if (info > 0) GMM_WARNING1("SuperLU solve failed: info =" << info);
  200. gmm::copy(sol, X);
  201. return info;
  202. }
  203. template <class T> class SuperLU_factor {
  204. typedef typename number_traits<T>::magnitude_type R;
  205. csc_matrix<T> csc_A;
  206. mutable SuperMatrix SA, SL, SB, SU, SX;
  207. mutable SuperLUStat_t stat;
  208. mutable superlu_options_t options;
  209. float memory_used;
  210. mutable std::vector<int> etree, perm_r, perm_c;
  211. mutable std::vector<R> Rscale, Cscale;
  212. mutable std::vector<R> ferr, berr;
  213. mutable std::vector<T> rhs;
  214. mutable std::vector<T> sol;
  215. mutable bool is_init;
  216. mutable char equed;
  217. public :
  218. enum { LU_NOTRANSP, LU_TRANSP, LU_CONJUGATED };
  219. void free_supermatrix(void);
  220. template <class MAT> void build_with(const MAT &A, int permc_spec = 3);
  221. template <typename VECTX, typename VECTB>
  222. /* transp = LU_NOTRANSP -> solves Ax = B
  223. transp = LU_TRANSP -> solves A'x = B
  224. transp = LU_CONJUGATED -> solves conj(A)X = B */
  225. void solve(const VECTX &X_, const VECTB &B, int transp=LU_NOTRANSP) const;
  226. SuperLU_factor(void) { is_init = false; }
  227. SuperLU_factor(const SuperLU_factor& other) {
  228. GMM_ASSERT2(!(other.is_init),
  229. "copy of initialized SuperLU_factor is forbidden");
  230. is_init = false;
  231. }
  232. SuperLU_factor& operator=(const SuperLU_factor& other) {
  233. GMM_ASSERT2(!(other.is_init) && !is_init,
  234. "assignment of initialized SuperLU_factor is forbidden");
  235. return *this;
  236. }
  237. ~SuperLU_factor() { free_supermatrix(); }
  238. float memsize() { return memory_used; }
  239. };
  240. template <class T> void SuperLU_factor<T>::free_supermatrix(void) {
  241. if (is_init) {
  242. if (SB.Store) Destroy_SuperMatrix_Store(&SB);
  243. if (SX.Store) Destroy_SuperMatrix_Store(&SX);
  244. if (SA.Store) Destroy_SuperMatrix_Store(&SA);
  245. if (SL.Store) Destroy_SuperNode_Matrix(&SL);
  246. if (SU.Store) Destroy_CompCol_Matrix(&SU);
  247. }
  248. }
  249. template <class T> template <class MAT>
  250. void SuperLU_factor<T>::build_with(const MAT &A, int permc_spec) {
  251. /*
  252. * Get column permutation vector perm_c[], according to permc_spec:
  253. * permc_spec = 0: use the natural ordering
  254. * permc_spec = 1: use minimum degree ordering on structure of A'*A
  255. * permc_spec = 2: use minimum degree ordering on structure of A'+A
  256. * permc_spec = 3: use approximate minimum degree column ordering
  257. */
  258. free_supermatrix();
  259. int n = mat_nrows(A), m = mat_ncols(A), info = 0;
  260. csc_A.init_with(A);
  261. rhs.resize(m); sol.resize(m);
  262. gmm::clear(rhs);
  263. int nz = nnz(csc_A);
  264. set_default_options(&options);
  265. options.ColPerm = NATURAL;
  266. options.PrintStat = NO;
  267. options.ConditionNumber = NO;
  268. switch (permc_spec) {
  269. case 1 : options.ColPerm = MMD_ATA; break;
  270. case 2 : options.ColPerm = MMD_AT_PLUS_A; break;
  271. case 3 : options.ColPerm = COLAMD; break;
  272. }
  273. StatInit(&stat);
  274. Create_CompCol_Matrix(&SA, m, n, nz, (double *)(&(csc_A.pr[0])),
  275. (int *)(&(csc_A.ir[0])), (int *)(&(csc_A.jc[0])));
  276. Create_Dense_Matrix(&SB, m, 0, &rhs[0], m);
  277. Create_Dense_Matrix(&SX, m, 0, &sol[0], m);
  278. memset(&SL,0,sizeof SL);
  279. memset(&SU,0,sizeof SU);
  280. equed = 'B';
  281. Rscale.resize(m); Cscale.resize(n); etree.resize(n);
  282. ferr.resize(1); berr.resize(1);
  283. R recip_pivot_gross, rcond;
  284. perm_r.resize(m); perm_c.resize(n);
  285. memory_used = SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
  286. &etree[0] /* output */, &equed /* output */,
  287. &Rscale[0] /* row scale factors (output) */,
  288. &Cscale[0] /* col scale factors (output) */,
  289. &SL /* fact L (output)*/, &SU /* fact U (output)*/,
  290. NULL /* work */,
  291. 0 /* lwork: superlu auto allocates (input) */,
  292. &SB /* rhs */, &SX /* solution */,
  293. &recip_pivot_gross /* reciprocal pivot growth */
  294. /* factor max_j( norm(A_j)/norm(U_j) ). */,
  295. &rcond /*estimate of the reciprocal condition */
  296. /* number of the matrix A after equilibration */,
  297. &ferr[0] /* estimated forward error */,
  298. &berr[0] /* relative backward error */,
  299. &stat, &info, T());
  300. Destroy_SuperMatrix_Store(&SB);
  301. Destroy_SuperMatrix_Store(&SX);
  302. Create_Dense_Matrix(&SB, m, 1, &rhs[0], m);
  303. Create_Dense_Matrix(&SX, m, 1, &sol[0], m);
  304. StatFree(&stat);
  305. GMM_ASSERT1(info == 0, "SuperLU solve failed: info=" << info);
  306. is_init = true;
  307. }
  308. template <class T> template <typename VECTX, typename VECTB>
  309. void SuperLU_factor<T>::solve(const VECTX &X_, const VECTB &B,
  310. int transp) const {
  311. VECTX &X = const_cast<VECTX &>(X_);
  312. gmm::copy(B, rhs);
  313. options.Fact = FACTORED;
  314. options.IterRefine = NOREFINE;
  315. switch (transp) {
  316. case LU_NOTRANSP: options.Trans = NOTRANS; break;
  317. case LU_TRANSP: options.Trans = TRANS; break;
  318. case LU_CONJUGATED: options.Trans = CONJ; break;
  319. default: GMM_ASSERT1(false, "invalid value for transposition option");
  320. }
  321. StatInit(&stat);
  322. int info = 0;
  323. R recip_pivot_gross, rcond;
  324. SuperLU_gssvx(&options, &SA, &perm_c[0], &perm_r[0],
  325. &etree[0] /* output */, &equed /* output */,
  326. &Rscale[0] /* row scale factors (output) */,
  327. &Cscale[0] /* col scale factors (output) */,
  328. &SL /* fact L (output)*/, &SU /* fact U (output)*/,
  329. NULL /* work */,
  330. 0 /* lwork: superlu auto allocates (input) */,
  331. &SB /* rhs */, &SX /* solution */,
  332. &recip_pivot_gross /* reciprocal pivot growth */
  333. /* factor max_j( norm(A_j)/norm(U_j) ). */,
  334. &rcond /*estimate of the reciprocal condition */
  335. /* number of the matrix A after equilibration */,
  336. &ferr[0] /* estimated forward error */,
  337. &berr[0] /* relative backward error */,
  338. &stat, &info, T());
  339. StatFree(&stat);
  340. GMM_ASSERT1(info == 0, "SuperLU solve failed: info=" << info);
  341. gmm::copy(sol, X);
  342. }
  343. template <typename T, typename V1, typename V2> inline
  344. void mult(const SuperLU_factor<T>& P, const V1 &v1, const V2 &v2) {
  345. P.solve(v2,v1);
  346. }
  347. template <typename T, typename V1, typename V2> inline
  348. void transposed_mult(const SuperLU_factor<T>& P,const V1 &v1,const V2 &v2) {
  349. P.solve(v2, v1, SuperLU_factor<T>::LU_TRANSP);
  350. }
  351. }
  352. #endif // GMM_SUPERLU_INTERFACE_H
  353. #endif // GMM_USES_SUPERLU