You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

205 lines
6.4 KiB

  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2011 Benoit Jacob <jacob.benoit.1@gmail.com>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #define EIGEN_NO_STATIC_ASSERT
  10. #include "main.h"
  11. template<typename ArrayType> void vectorwiseop_array(const ArrayType& m)
  12. {
  13. typedef typename ArrayType::Index Index;
  14. typedef typename ArrayType::Scalar Scalar;
  15. typedef Array<Scalar, ArrayType::RowsAtCompileTime, 1> ColVectorType;
  16. typedef Array<Scalar, 1, ArrayType::ColsAtCompileTime> RowVectorType;
  17. Index rows = m.rows();
  18. Index cols = m.cols();
  19. Index r = internal::random<Index>(0, rows-1),
  20. c = internal::random<Index>(0, cols-1);
  21. ArrayType m1 = ArrayType::Random(rows, cols),
  22. m2(rows, cols),
  23. m3(rows, cols);
  24. ColVectorType colvec = ColVectorType::Random(rows);
  25. RowVectorType rowvec = RowVectorType::Random(cols);
  26. // test addition
  27. m2 = m1;
  28. m2.colwise() += colvec;
  29. VERIFY_IS_APPROX(m2, m1.colwise() + colvec);
  30. VERIFY_IS_APPROX(m2.col(c), m1.col(c) + colvec);
  31. VERIFY_RAISES_ASSERT(m2.colwise() += colvec.transpose());
  32. VERIFY_RAISES_ASSERT(m1.colwise() + colvec.transpose());
  33. m2 = m1;
  34. m2.rowwise() += rowvec;
  35. VERIFY_IS_APPROX(m2, m1.rowwise() + rowvec);
  36. VERIFY_IS_APPROX(m2.row(r), m1.row(r) + rowvec);
  37. VERIFY_RAISES_ASSERT(m2.rowwise() += rowvec.transpose());
  38. VERIFY_RAISES_ASSERT(m1.rowwise() + rowvec.transpose());
  39. // test substraction
  40. m2 = m1;
  41. m2.colwise() -= colvec;
  42. VERIFY_IS_APPROX(m2, m1.colwise() - colvec);
  43. VERIFY_IS_APPROX(m2.col(c), m1.col(c) - colvec);
  44. VERIFY_RAISES_ASSERT(m2.colwise() -= colvec.transpose());
  45. VERIFY_RAISES_ASSERT(m1.colwise() - colvec.transpose());
  46. m2 = m1;
  47. m2.rowwise() -= rowvec;
  48. VERIFY_IS_APPROX(m2, m1.rowwise() - rowvec);
  49. VERIFY_IS_APPROX(m2.row(r), m1.row(r) - rowvec);
  50. VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
  51. VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
  52. // test multiplication
  53. m2 = m1;
  54. m2.colwise() *= colvec;
  55. VERIFY_IS_APPROX(m2, m1.colwise() * colvec);
  56. VERIFY_IS_APPROX(m2.col(c), m1.col(c) * colvec);
  57. VERIFY_RAISES_ASSERT(m2.colwise() *= colvec.transpose());
  58. VERIFY_RAISES_ASSERT(m1.colwise() * colvec.transpose());
  59. m2 = m1;
  60. m2.rowwise() *= rowvec;
  61. VERIFY_IS_APPROX(m2, m1.rowwise() * rowvec);
  62. VERIFY_IS_APPROX(m2.row(r), m1.row(r) * rowvec);
  63. VERIFY_RAISES_ASSERT(m2.rowwise() *= rowvec.transpose());
  64. VERIFY_RAISES_ASSERT(m1.rowwise() * rowvec.transpose());
  65. // test quotient
  66. m2 = m1;
  67. m2.colwise() /= colvec;
  68. VERIFY_IS_APPROX(m2, m1.colwise() / colvec);
  69. VERIFY_IS_APPROX(m2.col(c), m1.col(c) / colvec);
  70. VERIFY_RAISES_ASSERT(m2.colwise() /= colvec.transpose());
  71. VERIFY_RAISES_ASSERT(m1.colwise() / colvec.transpose());
  72. m2 = m1;
  73. m2.rowwise() /= rowvec;
  74. VERIFY_IS_APPROX(m2, m1.rowwise() / rowvec);
  75. VERIFY_IS_APPROX(m2.row(r), m1.row(r) / rowvec);
  76. VERIFY_RAISES_ASSERT(m2.rowwise() /= rowvec.transpose());
  77. VERIFY_RAISES_ASSERT(m1.rowwise() / rowvec.transpose());
  78. m2 = m1;
  79. // yes, there might be an aliasing issue there but ".rowwise() /="
  80. // is suppposed to evaluate " m2.colwise().sum()" into to temporary to avoid
  81. // evaluating the reducions multiple times
  82. if(ArrayType::RowsAtCompileTime>2 || ArrayType::RowsAtCompileTime==Dynamic)
  83. {
  84. m2.rowwise() /= m2.colwise().sum();
  85. VERIFY_IS_APPROX(m2, m1.rowwise() / m1.colwise().sum());
  86. }
  87. }
  88. template<typename MatrixType> void vectorwiseop_matrix(const MatrixType& m)
  89. {
  90. typedef typename MatrixType::Index Index;
  91. typedef typename MatrixType::Scalar Scalar;
  92. typedef typename NumTraits<Scalar>::Real RealScalar;
  93. typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType;
  94. typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType;
  95. typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealColVectorType;
  96. typedef Matrix<RealScalar, 1, MatrixType::ColsAtCompileTime> RealRowVectorType;
  97. Index rows = m.rows();
  98. Index cols = m.cols();
  99. Index r = internal::random<Index>(0, rows-1),
  100. c = internal::random<Index>(0, cols-1);
  101. MatrixType m1 = MatrixType::Random(rows, cols),
  102. m2(rows, cols),
  103. m3(rows, cols);
  104. ColVectorType colvec = ColVectorType::Random(rows);
  105. RowVectorType rowvec = RowVectorType::Random(cols);
  106. RealColVectorType rcres;
  107. RealRowVectorType rrres;
  108. // test addition
  109. m2 = m1;
  110. m2.colwise() += colvec;
  111. VERIFY_IS_APPROX(m2, m1.colwise() + colvec);
  112. VERIFY_IS_APPROX(m2.col(c), m1.col(c) + colvec);
  113. VERIFY_RAISES_ASSERT(m2.colwise() += colvec.transpose());
  114. VERIFY_RAISES_ASSERT(m1.colwise() + colvec.transpose());
  115. m2 = m1;
  116. m2.rowwise() += rowvec;
  117. VERIFY_IS_APPROX(m2, m1.rowwise() + rowvec);
  118. VERIFY_IS_APPROX(m2.row(r), m1.row(r) + rowvec);
  119. VERIFY_RAISES_ASSERT(m2.rowwise() += rowvec.transpose());
  120. VERIFY_RAISES_ASSERT(m1.rowwise() + rowvec.transpose());
  121. // test substraction
  122. m2 = m1;
  123. m2.colwise() -= colvec;
  124. VERIFY_IS_APPROX(m2, m1.colwise() - colvec);
  125. VERIFY_IS_APPROX(m2.col(c), m1.col(c) - colvec);
  126. VERIFY_RAISES_ASSERT(m2.colwise() -= colvec.transpose());
  127. VERIFY_RAISES_ASSERT(m1.colwise() - colvec.transpose());
  128. m2 = m1;
  129. m2.rowwise() -= rowvec;
  130. VERIFY_IS_APPROX(m2, m1.rowwise() - rowvec);
  131. VERIFY_IS_APPROX(m2.row(r), m1.row(r) - rowvec);
  132. VERIFY_RAISES_ASSERT(m2.rowwise() -= rowvec.transpose());
  133. VERIFY_RAISES_ASSERT(m1.rowwise() - rowvec.transpose());
  134. // test norm
  135. rrres = m1.colwise().norm();
  136. VERIFY_IS_APPROX(rrres(c), m1.col(c).norm());
  137. rcres = m1.rowwise().norm();
  138. VERIFY_IS_APPROX(rcres(r), m1.row(r).norm());
  139. // test normalized
  140. m2 = m1.colwise().normalized();
  141. VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
  142. m2 = m1.rowwise().normalized();
  143. VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
  144. // test normalize
  145. m2 = m1;
  146. m2.colwise().normalize();
  147. VERIFY_IS_APPROX(m2.col(c), m1.col(c).normalized());
  148. m2 = m1;
  149. m2.rowwise().normalize();
  150. VERIFY_IS_APPROX(m2.row(r), m1.row(r).normalized());
  151. }
  152. void test_vectorwiseop()
  153. {
  154. CALL_SUBTEST_1(vectorwiseop_array(Array22cd()));
  155. CALL_SUBTEST_2(vectorwiseop_array(Array<double, 3, 2>()));
  156. CALL_SUBTEST_3(vectorwiseop_array(ArrayXXf(3, 4)));
  157. CALL_SUBTEST_4(vectorwiseop_matrix(Matrix4cf()));
  158. CALL_SUBTEST_5(vectorwiseop_matrix(Matrix<float,4,5>()));
  159. CALL_SUBTEST_6(vectorwiseop_matrix(MatrixXd(7,2)));
  160. }