You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

249 lines
8.5 KiB

  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
  5. // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
  6. //
  7. // This Source Code Form is subject to the terms of the Mozilla
  8. // Public License v. 2.0. If a copy of the MPL was not distributed
  9. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  10. #define TEST_ENABLE_TEMPORARY_TRACKING
  11. #define STORMEIGEN_NO_STATIC_ASSERT
  12. #include "main.h"
  13. template<typename ArrayType> void vectorwiseop_array(const ArrayType& m)
  14. {
  15. typedef typename ArrayType::Index Index;
  16. typedef typename ArrayType::Scalar Scalar;
  17. typedef Array<Scalar, ArrayType::RowsAtCompileTime, 1> ColVectorType;
  18. typedef Array<Scalar, 1, ArrayType::ColsAtCompileTime> RowVectorType;
  19. Index rows = m.rows();
  20. Index cols = m.cols();
  21. Index r = internal::random<Index>(0, rows-1),
  22. c = internal::random<Index>(0, cols-1);
  23. ArrayType m1 = ArrayType::Random(rows, cols),
  24. m2(rows, cols),
  25. m3(rows, cols);
  26. ColVectorType colvec = ColVectorType::Random(rows);
  27. RowVectorType rowvec = RowVectorType::Random(cols);
  28. // test addition
  29. m2 = m1;
  30. m2.colwise() += colvec;
  31. VERIFY_IS_APPROX(m2, m1.colwise() + colvec);
  32. VERIFY_IS_APPROX(m2.col(c), m1.col(c) + colvec);
  33. VERIFY_RAISES_ASSERT(m2.colwise() += colvec.transpose());
  34. VERIFY_RAISES_ASSERT(m1.colwise() + colvec.transpose());
  35. m2 = m1;
  36. m2.rowwise() += rowvec;
  37. VERIFY_IS_APPROX(m2, m1.rowwise() + rowvec);
  38. VERIFY_IS_APPROX(m2.row(r), m1.row(r) + rowvec);
  39. VERIFY_RAISES_ASSERT(m2.rowwise() += rowvec.transpose());
  40. VERIFY_RAISES_ASSERT(m1.rowwise() + rowvec.transpose());
  41. // test substraction
  42. m2 = m1;
  43. m2.colwise() -= colvec;
  44. VERIFY_IS_APPROX(m2, m1.colwise() - colvec);
  45. VERIFY_IS_APPROX(m2.col(c), m1.col(c) - colvec);
  46. VERIFY_RAISES_ASSERT(m2.colwise() -= colvec.transpose());
  47. VERIFY_RAISES_ASSERT(m1.colwise() - colvec.transpose());
  48. m2 = m1;
  49. m2.rowwise() -= rowvec;
  50. VERIFY_IS_APPROX(m2, m1.rowwise() - rowvec);
  51. VERIFY_IS_APPROX(m2.row(r), m1.row(r) - rowvec);
  52. VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
  53. VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
  54. // test multiplication
  55. m2 = m1;
  56. m2.colwise() *= colvec;
  57. VERIFY_IS_APPROX(m2, m1.colwise() * colvec);
  58. VERIFY_IS_APPROX(m2.col(c), m1.col(c) * colvec);
  59. VERIFY_RAISES_ASSERT(m2.colwise() *= colvec.transpose());
  60. VERIFY_RAISES_ASSERT(m1.colwise() * colvec.transpose());
  61. m2 = m1;
  62. m2.rowwise() *= rowvec;
  63. VERIFY_IS_APPROX(m2, m1.rowwise() * rowvec);
  64. VERIFY_IS_APPROX(m2.row(r), m1.row(r) * rowvec);
  65. VERIFY_RAISES_ASSERT(m2.rowwise() *= rowvec.transpose());
  66. VERIFY_RAISES_ASSERT(m1.rowwise() * rowvec.transpose());
  67. // test quotient
  68. m2 = m1;
  69. m2.colwise() /= colvec;
  70. VERIFY_IS_APPROX(m2, m1.colwise() / colvec);
  71. VERIFY_IS_APPROX(m2.col(c), m1.col(c) / colvec);
  72. VERIFY_RAISES_ASSERT(m2.colwise() /= colvec.transpose());
  73. VERIFY_RAISES_ASSERT(m1.colwise() / colvec.transpose());
  74. m2 = m1;
  75. m2.rowwise() /= rowvec;
  76. VERIFY_IS_APPROX(m2, m1.rowwise() / rowvec);
  77. VERIFY_IS_APPROX(m2.row(r), m1.row(r) / rowvec);
  78. VERIFY_RAISES_ASSERT(m2.rowwise() /= rowvec.transpose());
  79. VERIFY_RAISES_ASSERT(m1.rowwise() / rowvec.transpose());
  80. m2 = m1;
  81. // yes, there might be an aliasing issue there but ".rowwise() /="
  82. // is supposed to evaluate " m2.colwise().sum()" into a temporary to avoid
  83. // evaluating the reduction multiple times
  84. if(ArrayType::RowsAtCompileTime>2 || ArrayType::RowsAtCompileTime==Dynamic)
  85. {
  86. m2.rowwise() /= m2.colwise().sum();
  87. VERIFY_IS_APPROX(m2, m1.rowwise() / m1.colwise().sum());
  88. }
  89. // all/any
  90. Array<bool,Dynamic,Dynamic> mb(rows,cols);
  91. mb = (m1.real()<=0.7).colwise().all();
  92. VERIFY( (mb.col(c) == (m1.real().col(c)<=0.7).all()).all() );
  93. mb = (m1.real()<=0.7).rowwise().all();
  94. VERIFY( (mb.row(r) == (m1.real().row(r)<=0.7).all()).all() );
  95. mb = (m1.real()>=0.7).colwise().any();
  96. VERIFY( (mb.col(c) == (m1.real().col(c)>=0.7).any()).all() );
  97. mb = (m1.real()>=0.7).rowwise().any();
  98. VERIFY( (mb.row(r) == (m1.real().row(r)>=0.7).any()).all() );
  99. }
  100. template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
  101. {
  102. typedef typename MatrixType::Index Index;
  103. typedef typename MatrixType::Scalar Scalar;
  104. typedef typename NumTraits<Scalar>::Real RealScalar;
  105. typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType;
  106. typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType;
  107. typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealColVectorType;
  108. typedef Matrix<RealScalar, 1, MatrixType::ColsAtCompileTime> RealRowVectorType;
  109. Index rows = m.rows();
  110. Index cols = m.cols();
  111. Index r = internal::random<Index>(0, rows-1),
  112. c = internal::random<Index>(0, cols-1);
  113. MatrixType m1 = MatrixType::Random(rows, cols),
  114. m2(rows, cols),
  115. m3(rows, cols);
  116. ColVectorType colvec = ColVectorType::Random(rows);
  117. RowVectorType rowvec = RowVectorType::Random(cols);
  118. RealColVectorType rcres;
  119. RealRowVectorType rrres;
  120. // test addition
  121. m2 = m1;
  122. m2.colwise() += colvec;
  123. VERIFY_IS_APPROX(m2, m1.colwise() + colvec);
  124. VERIFY_IS_APPROX(m2.col(c), m1.col(c) + colvec);
  125. if(rows>1)
  126. {
  127. VERIFY_RAISES_ASSERT(m2.colwise() += colvec.transpose());
  128. VERIFY_RAISES_ASSERT(m1.colwise() + colvec.transpose());
  129. }
  130. m2 = m1;
  131. m2.rowwise() += rowvec;
  132. VERIFY_IS_APPROX(m2, m1.rowwise() + rowvec);
  133. VERIFY_IS_APPROX(m2.row(r), m1.row(r) + rowvec);
  134. if(cols>1)
  135. {
  136. VERIFY_RAISES_ASSERT(m2.rowwise() += rowvec.transpose());
  137. VERIFY_RAISES_ASSERT(m1.rowwise() + rowvec.transpose());
  138. }
  139. // test substraction
  140. m2 = m1;
  141. m2.colwise() -= colvec;
  142. VERIFY_IS_APPROX(m2, m1.colwise() - colvec);
  143. VERIFY_IS_APPROX(m2.col(c), m1.col(c) - colvec);
  144. if(rows>1)
  145. {
  146. VERIFY_RAISES_ASSERT(m2.colwise() -= colvec.transpose());
  147. VERIFY_RAISES_ASSERT(m1.colwise() - colvec.transpose());
  148. }
  149. m2 = m1;
  150. m2.rowwise() -= rowvec;
  151. VERIFY_IS_APPROX(m2, m1.rowwise() - rowvec);
  152. VERIFY_IS_APPROX(m2.row(r), m1.row(r) - rowvec);
  153. if(cols>1)
  154. {
  155. VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
  156. VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
  157. }
  158. // test norm
  159. rrres = m1.colwise().norm();
  160. VERIFY_IS_APPROX(rrres(c), m1.col(c).norm());
  161. rcres = m1.rowwise().norm();
  162. VERIFY_IS_APPROX(rcres(r), m1.row(r).norm());
  163. VERIFY_IS_APPROX(m1.cwiseAbs().colwise().sum(), m1.colwise().template lpNorm<1>());
  164. VERIFY_IS_APPROX(m1.cwiseAbs().rowwise().sum(), m1.rowwise().template lpNorm<1>());
  165. VERIFY_IS_APPROX(m1.cwiseAbs().colwise().maxCoeff(), m1.colwise().template lpNorm<Infinity>());
  166. VERIFY_IS_APPROX(m1.cwiseAbs().rowwise().maxCoeff(), m1.rowwise().template lpNorm<Infinity>());
  167. // test normalized
  168. m2 = m1.colwise().normalized();
  169. VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
  170. m2 = m1.rowwise().normalized();
  171. VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
  172. // test normalize
  173. m2 = m1;
  174. m2.colwise().normalize();
  175. VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
  176. m2 = m1;
  177. m2.rowwise().normalize();
  178. VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
  179. // test with partial reduction of products
  180. Matrix<Scalar,MatrixType::RowsAtCompileTime,MatrixType::RowsAtCompileTime> m1m1 = m1 * m1.transpose();
  181. VERIFY_IS_APPROX( (m1 * m1.transpose()).colwise().sum(), m1m1.colwise().sum());
  182. Matrix<Scalar,1,MatrixType::RowsAtCompileTime> tmp(rows);
  183. VERIFY_EVALUATION_COUNT( tmp = (m1 * m1.transpose()).colwise().sum(), (MatrixType::RowsAtCompileTime==Dynamic ? 1 : 0));
  184. m2 = m1.rowwise() - (m1.colwise().sum()/m1.rows()).eval();
  185. m1 = m1.rowwise() - (m1.colwise().sum()/m1.rows());
  186. VERIFY_IS_APPROX( m1, m2 );
  187. VERIFY_EVALUATION_COUNT( m2 = (m1.rowwise() - m1.colwise().sum()/m1.rows()), (MatrixType::RowsAtCompileTime==Dynamic && MatrixType::ColsAtCompileTime!=1 ? 1 : 0) );
  188. }
  189. void test_vectorwiseop()
  190. {
  191. CALL_SUBTEST_1( vectorwiseop_array(Array22cd()) );
  192. CALL_SUBTEST_2( vectorwiseop_array(Array<double, 3, 2>()) );
  193. CALL_SUBTEST_3( vectorwiseop_array(ArrayXXf(3, 4)) );
  194. CALL_SUBTEST_4( vectorwiseop_matrix(Matrix4cf()) );
  195. CALL_SUBTEST_5( vectorwiseop_matrix(Matrix<float,4,5>()) );
  196. CALL_SUBTEST_6( vectorwiseop_matrix(MatrixXd(internal::random<int>(1,STORMEIGEN_TEST_MAX_SIZE), internal::random<int>(1,STORMEIGEN_TEST_MAX_SIZE))) );
  197. CALL_SUBTEST_7( vectorwiseop_matrix(VectorXd(internal::random<int>(1,STORMEIGEN_TEST_MAX_SIZE))) );
  198. CALL_SUBTEST_7( vectorwiseop_matrix(RowVectorXd(internal::random<int>(1,STORMEIGEN_TEST_MAX_SIZE))) );
  199. }