You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

147 lines
4.0 KiB

  1. //=====================================================
  2. // File : blitz_interface.hh
  3. // Author : L. Plagne <laurent.plagne@edf.fr>
  4. // Copyright (C) EDF R&D, lun sep 30 14:23:30 CEST 2002
  5. // Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
  6. //=====================================================
  7. //
  8. // This program is free software; you can redistribute it and/or
  9. // modify it under the terms of the GNU General Public License
  10. // as published by the Free Software Foundation; either version 2
  11. // of the License, or (at your option) any later version.
  12. //
  13. // This program is distributed in the hope that it will be useful,
  14. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  15. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  16. // GNU General Public License for more details.
  17. // You should have received a copy of the GNU General Public License
  18. // along with this program; if not, write to the Free Software
  19. // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
  20. //
  21. #ifndef BLITZ_INTERFACE_HH
  22. #define BLITZ_INTERFACE_HH
  23. #include <blitz/blitz.h>
  24. #include <blitz/array.h>
  25. #include <blitz/vector-et.h>
  26. #include <blitz/vecwhere.h>
  27. #include <blitz/matrix.h>
  28. #include <vector>
  29. BZ_USING_NAMESPACE(blitz)
  30. template<class real>
  31. class blitz_interface{
  32. public :
  33. typedef real real_type ;
  34. typedef std::vector<real> stl_vector;
  35. typedef std::vector<stl_vector > stl_matrix;
  36. typedef blitz::Array<real, 2> gene_matrix;
  37. typedef blitz::Array<real, 1> gene_vector;
  38. // typedef blitz::Matrix<real, blitz::ColumnMajor> gene_matrix;
  39. // typedef blitz::Vector<real> gene_vector;
  40. static inline std::string name() { return "blitz"; }
  41. static void free_matrix(gene_matrix & A, int N){}
  42. static void free_vector(gene_vector & B){}
  43. static inline void matrix_from_stl(gene_matrix & A, stl_matrix & A_stl){
  44. A.resize(A_stl[0].size(),A_stl.size());
  45. for (int j=0; j<A_stl.size() ; j++){
  46. for (int i=0; i<A_stl[j].size() ; i++){
  47. A(i,j)=A_stl[j][i];
  48. }
  49. }
  50. }
  51. static inline void vector_from_stl(gene_vector & B, stl_vector & B_stl){
  52. B.resize(B_stl.size());
  53. for (int i=0; i<B_stl.size() ; i++){
  54. B(i)=B_stl[i];
  55. }
  56. }
  57. static inline void vector_to_stl(gene_vector & B, stl_vector & B_stl){
  58. for (int i=0; i<B_stl.size() ; i++){
  59. B_stl[i]=B(i);
  60. }
  61. }
  62. static inline void matrix_to_stl(gene_matrix & A, stl_matrix & A_stl){
  63. int N=A_stl.size();
  64. for (int j=0;j<N;j++){
  65. A_stl[j].resize(N);
  66. for (int i=0;i<N;i++)
  67. A_stl[j][i] = A(i,j);
  68. }
  69. }
  70. static inline void matrix_matrix_product(const gene_matrix & A, const gene_matrix & B, gene_matrix & X, int N)
  71. {
  72. firstIndex i;
  73. secondIndex j;
  74. thirdIndex k;
  75. X = sum(A(i,k) * B(k,j), k);
  76. }
  77. static inline void ata_product(const gene_matrix & A, gene_matrix & X, int N)
  78. {
  79. firstIndex i;
  80. secondIndex j;
  81. thirdIndex k;
  82. X = sum(A(k,i) * A(k,j), k);
  83. }
  84. static inline void aat_product(const gene_matrix & A, gene_matrix & X, int N)
  85. {
  86. firstIndex i;
  87. secondIndex j;
  88. thirdIndex k;
  89. X = sum(A(i,k) * A(j,k), k);
  90. }
  91. static inline void matrix_vector_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N)
  92. {
  93. firstIndex i;
  94. secondIndex j;
  95. X = sum(A(i,j)*B(j),j);
  96. }
  97. static inline void atv_product(gene_matrix & A, gene_vector & B, gene_vector & X, int N)
  98. {
  99. firstIndex i;
  100. secondIndex j;
  101. X = sum(A(j,i) * B(j),j);
  102. }
  103. static inline void axpy(const real coef, const gene_vector & X, gene_vector & Y, int N)
  104. {
  105. firstIndex i;
  106. Y = Y(i) + coef * X(i);
  107. //Y += coef * X;
  108. }
  109. static inline void copy_matrix(const gene_matrix & source, gene_matrix & cible, int N){
  110. cible = source;
  111. //cible.template operator=<gene_matrix>(source);
  112. // for (int i=0;i<N;i++){
  113. // for (int j=0;j<N;j++){
  114. // cible(i,j)=source(i,j);
  115. // }
  116. // }
  117. }
  118. static inline void copy_vector(const gene_vector & source, gene_vector & cible, int N){
  119. //cible.template operator=<gene_vector>(source);
  120. cible = source;
  121. }
  122. };
  123. #endif