You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

277 lines
10 KiB

  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2015 Gael Guennebaud <gael.guennebaud@inria.fr>
  5. //
  6. // This Source Code Form is subject to the terms of the Mozilla
  7. // Public License v. 2.0. If a copy of the MPL was not distributed
  8. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  9. #ifndef STORMEIGEN_SPARSE_COMPRESSED_BASE_H
  10. #define STORMEIGEN_SPARSE_COMPRESSED_BASE_H
  11. namespace StormEigen {
  12. template<typename Derived> class SparseCompressedBase;
  13. namespace internal {
  14. template<typename Derived>
  15. struct traits<SparseCompressedBase<Derived> > : traits<Derived>
  16. {};
  17. } // end namespace internal
  18. template<typename Derived>
  19. class SparseCompressedBase
  20. : public SparseMatrixBase<Derived>
  21. {
  22. public:
  23. typedef SparseMatrixBase<Derived> Base;
  24. STORMEIGEN_SPARSE_PUBLIC_INTERFACE(SparseCompressedBase)
  25. using Base::operator=;
  26. using Base::IsRowMajor;
  27. class InnerIterator;
  28. class ReverseInnerIterator;
  29. protected:
  30. typedef typename Base::IndexVector IndexVector;
  31. StormEigen::Map<IndexVector> innerNonZeros() { return StormEigen::Map<IndexVector>(innerNonZeroPtr(), isCompressed()?0:derived().outerSize()); }
  32. const StormEigen::Map<const IndexVector> innerNonZeros() const { return StormEigen::Map<const IndexVector>(innerNonZeroPtr(), isCompressed()?0:derived().outerSize()); }
  33. public:
  34. /** \returns the number of non zero coefficients */
  35. inline Index nonZeros() const
  36. {
  37. if(Derived::IsVectorAtCompileTime && outerIndexPtr()==0)
  38. return derived().nonZeros();
  39. else if(isCompressed())
  40. return outerIndexPtr()[derived().outerSize()]-outerIndexPtr()[0];
  41. else if(derived().outerSize()==0)
  42. return 0;
  43. else
  44. return innerNonZeros().sum();
  45. }
  46. /** \returns a const pointer to the array of values.
  47. * This function is aimed at interoperability with other libraries.
  48. * \sa innerIndexPtr(), outerIndexPtr() */
  49. inline const Scalar* valuePtr() const { return derived().valuePtr(); }
  50. /** \returns a non-const pointer to the array of values.
  51. * This function is aimed at interoperability with other libraries.
  52. * \sa innerIndexPtr(), outerIndexPtr() */
  53. inline Scalar* valuePtr() { return derived().valuePtr(); }
  54. /** \returns a const pointer to the array of inner indices.
  55. * This function is aimed at interoperability with other libraries.
  56. * \sa valuePtr(), outerIndexPtr() */
  57. inline const StorageIndex* innerIndexPtr() const { return derived().innerIndexPtr(); }
  58. /** \returns a non-const pointer to the array of inner indices.
  59. * This function is aimed at interoperability with other libraries.
  60. * \sa valuePtr(), outerIndexPtr() */
  61. inline StorageIndex* innerIndexPtr() { return derived().innerIndexPtr(); }
  62. /** \returns a const pointer to the array of the starting positions of the inner vectors.
  63. * This function is aimed at interoperability with other libraries.
  64. * \warning it returns the null pointer 0 for SparseVector
  65. * \sa valuePtr(), innerIndexPtr() */
  66. inline const StorageIndex* outerIndexPtr() const { return derived().outerIndexPtr(); }
  67. /** \returns a non-const pointer to the array of the starting positions of the inner vectors.
  68. * This function is aimed at interoperability with other libraries.
  69. * \warning it returns the null pointer 0 for SparseVector
  70. * \sa valuePtr(), innerIndexPtr() */
  71. inline StorageIndex* outerIndexPtr() { return derived().outerIndexPtr(); }
  72. /** \returns a const pointer to the array of the number of non zeros of the inner vectors.
  73. * This function is aimed at interoperability with other libraries.
  74. * \warning it returns the null pointer 0 in compressed mode */
  75. inline const StorageIndex* innerNonZeroPtr() const { return derived().innerNonZeroPtr(); }
  76. /** \returns a non-const pointer to the array of the number of non zeros of the inner vectors.
  77. * This function is aimed at interoperability with other libraries.
  78. * \warning it returns the null pointer 0 in compressed mode */
  79. inline StorageIndex* innerNonZeroPtr() { return derived().innerNonZeroPtr(); }
  80. /** \returns whether \c *this is in compressed form. */
  81. inline bool isCompressed() const { return innerNonZeroPtr()==0; }
  82. protected:
  83. /** Default constructor. Do nothing. */
  84. SparseCompressedBase() {}
  85. private:
  86. template<typename OtherDerived> explicit SparseCompressedBase(const SparseCompressedBase<OtherDerived>&);
  87. };
  88. template<typename Derived>
  89. class SparseCompressedBase<Derived>::InnerIterator
  90. {
  91. public:
  92. InnerIterator(const SparseCompressedBase& mat, Index outer)
  93. : m_values(mat.valuePtr()), m_indices(mat.innerIndexPtr()), m_outer(outer)
  94. {
  95. if(Derived::IsVectorAtCompileTime && mat.outerIndexPtr()==0)
  96. {
  97. m_id = 0;
  98. m_end = mat.nonZeros();
  99. }
  100. else
  101. {
  102. m_id = mat.outerIndexPtr()[outer];
  103. if(mat.isCompressed())
  104. m_end = mat.outerIndexPtr()[outer+1];
  105. else
  106. m_end = m_id + mat.innerNonZeroPtr()[outer];
  107. }
  108. }
  109. explicit InnerIterator(const SparseCompressedBase& mat)
  110. : m_values(mat.valuePtr()), m_indices(mat.innerIndexPtr()), m_outer(0), m_id(0), m_end(mat.nonZeros())
  111. {
  112. STORMEIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
  113. }
  114. explicit InnerIterator(const internal::CompressedStorage<Scalar,StorageIndex>& data)
  115. : m_values(&data.value(0)), m_indices(&data.index(0)), m_outer(0), m_id(0), m_end(data.size())
  116. {
  117. STORMEIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
  118. }
  119. inline InnerIterator& operator++() { m_id++; return *this; }
  120. inline const Scalar& value() const { return m_values[m_id]; }
  121. inline Scalar& valueRef() { return const_cast<Scalar&>(m_values[m_id]); }
  122. inline StorageIndex index() const { return m_indices[m_id]; }
  123. inline Index outer() const { return m_outer.value(); }
  124. inline Index row() const { return IsRowMajor ? m_outer.value() : index(); }
  125. inline Index col() const { return IsRowMajor ? index() : m_outer.value(); }
  126. inline operator bool() const { return (m_id < m_end); }
  127. protected:
  128. const Scalar* m_values;
  129. const StorageIndex* m_indices;
  130. const internal::variable_if_dynamic<Index,Derived::IsVectorAtCompileTime?0:Dynamic> m_outer;
  131. Index m_id;
  132. Index m_end;
  133. private:
  134. // If you get here, then you're not using the right InnerIterator type, e.g.:
  135. // SparseMatrix<double,RowMajor> A;
  136. // SparseMatrix<double>::InnerIterator it(A,0);
  137. template<typename T> InnerIterator(const SparseMatrixBase<T>&, Index outer);
  138. };
  139. template<typename Derived>
  140. class SparseCompressedBase<Derived>::ReverseInnerIterator
  141. {
  142. public:
  143. ReverseInnerIterator(const SparseCompressedBase& mat, Index outer)
  144. : m_values(mat.valuePtr()), m_indices(mat.innerIndexPtr()), m_outer(outer)
  145. {
  146. if(Derived::IsVectorAtCompileTime && mat.outerIndexPtr()==0)
  147. {
  148. m_start = 0;
  149. m_id = mat.nonZeros();
  150. }
  151. else
  152. {
  153. m_start.value() = mat.outerIndexPtr()[outer];
  154. if(mat.isCompressed())
  155. m_id = mat.outerIndexPtr()[outer+1];
  156. else
  157. m_id = m_start.value() + mat.innerNonZeroPtr()[outer];
  158. }
  159. }
  160. explicit ReverseInnerIterator(const SparseCompressedBase& mat)
  161. : m_values(mat.valuePtr()), m_indices(mat.innerIndexPtr()), m_outer(0), m_start(0), m_id(mat.nonZeros())
  162. {
  163. STORMEIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
  164. }
  165. explicit ReverseInnerIterator(const internal::CompressedStorage<Scalar,StorageIndex>& data)
  166. : m_values(&data.value(0)), m_indices(&data.index(0)), m_outer(0), m_start(0), m_id(data.size())
  167. {
  168. STORMEIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived);
  169. }
  170. inline ReverseInnerIterator& operator--() { --m_id; return *this; }
  171. inline const Scalar& value() const { return m_values[m_id-1]; }
  172. inline Scalar& valueRef() { return const_cast<Scalar&>(m_values[m_id-1]); }
  173. inline StorageIndex index() const { return m_indices[m_id-1]; }
  174. inline Index outer() const { return m_outer.value(); }
  175. inline Index row() const { return IsRowMajor ? m_outer.value() : index(); }
  176. inline Index col() const { return IsRowMajor ? index() : m_outer.value(); }
  177. inline operator bool() const { return (m_id > m_start.value()); }
  178. protected:
  179. const Scalar* m_values;
  180. const StorageIndex* m_indices;
  181. const internal::variable_if_dynamic<Index,Derived::IsVectorAtCompileTime?0:Dynamic> m_outer;
  182. Index m_id;
  183. const internal::variable_if_dynamic<Index,Derived::IsVectorAtCompileTime?0:Dynamic> m_start;
  184. };
  185. namespace internal {
  186. template<typename Derived>
  187. struct evaluator<SparseCompressedBase<Derived> >
  188. : evaluator_base<Derived>
  189. {
  190. typedef typename Derived::Scalar Scalar;
  191. typedef typename Derived::InnerIterator InnerIterator;
  192. typedef typename Derived::ReverseInnerIterator ReverseInnerIterator;
  193. enum {
  194. CoeffReadCost = NumTraits<Scalar>::ReadCost,
  195. Flags = Derived::Flags
  196. };
  197. evaluator() : m_matrix(0)
  198. {
  199. STORMEIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  200. }
  201. explicit evaluator(const Derived &mat) : m_matrix(&mat)
  202. {
  203. STORMEIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
  204. }
  205. inline Index nonZerosEstimate() const {
  206. return m_matrix->nonZeros();
  207. }
  208. operator Derived&() { return m_matrix->const_cast_derived(); }
  209. operator const Derived&() const { return *m_matrix; }
  210. typedef typename DenseCoeffsBase<Derived,ReadOnlyAccessors>::CoeffReturnType CoeffReturnType;
  211. Scalar coeff(Index row, Index col) const
  212. { return m_matrix->coeff(row,col); }
  213. Scalar& coeffRef(Index row, Index col)
  214. {
  215. eigen_internal_assert(row>=0 && row<m_matrix->rows() && col>=0 && col<m_matrix->cols());
  216. const Index outer = Derived::IsRowMajor ? row : col;
  217. const Index inner = Derived::IsRowMajor ? col : row;
  218. Index start = m_matrix->outerIndexPtr()[outer];
  219. Index end = m_matrix->isCompressed() ? m_matrix->outerIndexPtr()[outer+1] : m_matrix->outerIndexPtr()[outer] + m_matrix->innerNonZeroPtr()[outer];
  220. eigen_assert(end>start && "you are using a non finalized sparse matrix or written coefficient does not exist");
  221. const Index p = std::lower_bound(m_matrix->innerIndexPtr()+start, m_matrix->innerIndexPtr()+end,inner)
  222. - m_matrix->innerIndexPtr();
  223. eigen_assert((p<end) && (m_matrix->innerIndexPtr()[p]==inner) && "written coefficient does not exist");
  224. return m_matrix->const_cast_derived().valuePtr()[p];
  225. }
  226. const Derived *m_matrix;
  227. };
  228. }
  229. } // end namespace StormEigen
  230. #endif // STORMEIGEN_SPARSE_COMPRESSED_BASE_H