You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

589 lines
28 KiB

8 years ago
  1. /*
  2. pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices
  3. Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
  4. All rights reserved. Use of this source code is governed by a
  5. BSD-style license that can be found in the LICENSE file.
  6. */
  7. #pragma once
  8. #include "numpy.h"
  9. #if defined(__INTEL_COMPILER)
  10. # pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem)
  11. #elif defined(__GNUG__) || defined(__clang__)
  12. # pragma GCC diagnostic push
  13. # pragma GCC diagnostic ignored "-Wconversion"
  14. # pragma GCC diagnostic ignored "-Wdeprecated-declarations"
  15. # if __GNUC__ >= 7
  16. # pragma GCC diagnostic ignored "-Wint-in-bool-context"
  17. # endif
  18. #endif
  19. #include <Eigen/Core>
  20. #include <Eigen/SparseCore>
  21. #if defined(_MSC_VER)
  22. # pragma warning(push)
  23. # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
  24. #endif
  25. // Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit
  26. // move constructors that break things. We could detect this an explicitly copy, but an extra copy
  27. // of matrices seems highly undesirable.
  28. static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7");
  29. NAMESPACE_BEGIN(pybind11)
// Convenience aliases for pass-by-reference usage with fully dynamic strides: a
// Ref/Map parameterized with EigenDStride can wrap data of any inner/outer stride.
using EigenDStride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;
template <typename MatrixType> using EigenDRef = Eigen::Ref<MatrixType, 0, EigenDStride>;
template <typename MatrixType> using EigenDMap = Eigen::Map<MatrixType, 0, EigenDStride>;

NAMESPACE_BEGIN(detail)

// Eigen's signed index type: `Eigen::Index` only exists from Eigen 3.3 onward;
// older versions expose it via the EIGEN_DEFAULT_DENSE_INDEX_TYPE macro.
#if EIGEN_VERSION_AT_LEAST(3,3,0)
using EigenIndex = Eigen::Index;
#else
using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE;
#endif
// Matches Eigen::Map, Eigen::Ref, blocks, etc: dense types that are views over
// external storage (i.e. derive from MapBase with at least read access).
template <typename T> using is_eigen_dense_map = all_of<is_template_base_of<Eigen::DenseBase, T>, std::is_base_of<Eigen::MapBase<T, Eigen::ReadOnlyAccessors>, T>>;
// True for map-like types through which the mapped data may be *written*.
template <typename T> using is_eigen_mutable_map = std::is_base_of<Eigen::MapBase<T, Eigen::WriteAccessors>, T>;
// Plain, storage-owning dense types (Matrix/Array), explicitly excluding maps/refs.
template <typename T> using is_eigen_dense_plain = all_of<negation<is_eigen_dense_map<T>>, is_template_base_of<Eigen::PlainObjectBase, T>>;
// Sparse matrices and sparse expressions.
template <typename T> using is_eigen_sparse = is_template_base_of<Eigen::SparseMatrixBase, T>;
// Test for objects inheriting from EigenBase<Derived> that aren't captured by the above.  This
// basically covers anything that can be assigned to a dense matrix but that don't have a typical
// matrix data layout that can be copied from their .data().  For example, DiagonalMatrix and
// SelfAdjointView fall into this category.
template <typename T> using is_eigen_other = all_of<
    is_template_base_of<Eigen::EigenBase, T>,
    negation<any_of<is_eigen_dense_map<T>, is_eigen_dense_plain<T>, is_eigen_sparse<T>>>
>;
  53. // Captures numpy/eigen conformability status (returned by EigenProps::conformable()):
  54. template <bool EigenRowMajor> struct EigenConformable {
  55. bool conformable = false;
  56. EigenIndex rows = 0, cols = 0;
  57. EigenDStride stride{0, 0};
  58. EigenConformable(bool fits = false) : conformable{fits} {}
  59. // Matrix type:
  60. EigenConformable(EigenIndex r, EigenIndex c,
  61. EigenIndex rstride, EigenIndex cstride) :
  62. conformable{true}, rows{r}, cols{c},
  63. stride(EigenRowMajor ? rstride : cstride /* outer stride */,
  64. EigenRowMajor ? cstride : rstride /* inner stride */)
  65. {}
  66. // Vector type:
  67. EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride)
  68. : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {}
  69. template <typename props> bool stride_compatible() const {
  70. // To have compatible strides, we need (on both dimensions) one of fully dynamic strides,
  71. // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant)
  72. return
  73. (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() ||
  74. (EigenRowMajor ? cols : rows) == 1) &&
  75. (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() ||
  76. (EigenRowMajor ? rows : cols) == 1);
  77. }
  78. operator bool() const { return conformable; }
  79. };
// Extracts the stride type of a Map/Ref.  For any other Eigen type the "stride type"
// is the type itself (plain Matrix/Array types provide their own
// Inner/OuterStrideAtCompileTime members, which is all EigenProps reads from it).
template <typename Type> struct eigen_extract_stride { using type = Type; };
template <typename PlainObjectType, int MapOptions, typename StrideType>
struct eigen_extract_stride<Eigen::Map<PlainObjectType, MapOptions, StrideType>> { using type = StrideType; };
template <typename PlainObjectType, int Options, typename StrideType>
struct eigen_extract_stride<Eigen::Ref<PlainObjectType, Options, StrideType>> { using type = StrideType; };
  85. // Helper struct for extracting information from an Eigen type
  86. template <typename Type_> struct EigenProps {
  87. using Type = Type_;
  88. using Scalar = typename Type::Scalar;
  89. using StrideType = typename eigen_extract_stride<Type>::type;
  90. static constexpr EigenIndex
  91. rows = Type::RowsAtCompileTime,
  92. cols = Type::ColsAtCompileTime,
  93. size = Type::SizeAtCompileTime;
  94. static constexpr bool
  95. row_major = Type::IsRowMajor,
  96. vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1
  97. fixed_rows = rows != Eigen::Dynamic,
  98. fixed_cols = cols != Eigen::Dynamic,
  99. fixed = size != Eigen::Dynamic, // Fully-fixed size
  100. dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size
  101. template <EigenIndex i, EigenIndex ifzero> using if_zero = std::integral_constant<EigenIndex, i == 0 ? ifzero : i>;
  102. static constexpr EigenIndex inner_stride = if_zero<StrideType::InnerStrideAtCompileTime, 1>::value,
  103. outer_stride = if_zero<StrideType::OuterStrideAtCompileTime,
  104. vector ? size : row_major ? cols : rows>::value;
  105. static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic;
  106. static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1;
  107. static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1;
  108. // Takes an input array and determines whether we can make it fit into the Eigen type. If
  109. // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector
  110. // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type).
  111. static EigenConformable<row_major> conformable(const array &a) {
  112. const auto dims = a.ndim();
  113. if (dims < 1 || dims > 2)
  114. return false;
  115. if (dims == 2) { // Matrix type: require exact match (or dynamic)
  116. EigenIndex
  117. np_rows = a.shape(0),
  118. np_cols = a.shape(1),
  119. np_rstride = a.strides(0) / sizeof(Scalar),
  120. np_cstride = a.strides(1) / sizeof(Scalar);
  121. if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols))
  122. return false;
  123. return {np_rows, np_cols, np_rstride, np_cstride};
  124. }
  125. // Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever
  126. // is used, we want the (single) numpy stride value.
  127. const EigenIndex n = a.shape(0),
  128. stride = a.strides(0) / sizeof(Scalar);
  129. if (vector) { // Eigen type is a compile-time vector
  130. if (fixed && size != n)
  131. return false; // Vector size mismatch
  132. return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride};
  133. }
  134. else if (fixed) {
  135. // The type has a fixed size, but is not a vector: abort
  136. return false;
  137. }
  138. else if (fixed_cols) {
  139. // Since this isn't a vector, cols must be != 1. We allow this only if it exactly
  140. // equals the number of elements (rows is Dynamic, and so 1 row is allowed).
  141. if (cols != n) return false;
  142. return {1, n, stride};
  143. }
  144. else {
  145. // Otherwise it's either fully dynamic, or column dynamic; both become a column vector
  146. if (fixed_rows && rows != n) return false;
  147. return {n, 1, stride};
  148. }
  149. }
  150. static PYBIND11_DESCR descriptor() {
  151. constexpr bool show_writeable = is_eigen_dense_map<Type>::value && is_eigen_mutable_map<Type>::value;
  152. constexpr bool show_order = is_eigen_dense_map<Type>::value;
  153. constexpr bool show_c_contiguous = show_order && requires_row_major;
  154. constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major;
  155. return _("numpy.ndarray[") + npy_format_descriptor<Scalar>::name() +
  156. _("[") + _<fixed_rows>(_<(size_t) rows>(), _("m")) +
  157. _(", ") + _<fixed_cols>(_<(size_t) cols>(), _("n")) +
  158. _("]") +
  159. // For a reference type (e.g. Ref<MatrixXd>) we have other constraints that might need to be
  160. // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride
  161. // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output
  162. // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to
  163. // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you
  164. // *gave* a numpy.ndarray of the right type and dimensions.
  165. _<show_writeable>(", flags.writeable", "") +
  166. _<show_c_contiguous>(", flags.c_contiguous", "") +
  167. _<show_f_contiguous>(", flags.f_contiguous", "") +
  168. _("]");
  169. }
  170. };
  171. // Casts an Eigen type to numpy array. If given a base, the numpy array references the src data,
  172. // otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
  173. template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
  174. constexpr size_t elem_size = sizeof(typename props::Scalar);
  175. std::vector<size_t> shape, strides;
  176. if (props::vector) {
  177. shape.push_back(src.size());
  178. strides.push_back(elem_size * src.innerStride());
  179. }
  180. else {
  181. shape.push_back(src.rows());
  182. shape.push_back(src.cols());
  183. strides.push_back(elem_size * src.rowStride());
  184. strides.push_back(elem_size * src.colStride());
  185. }
  186. array a(std::move(shape), std::move(strides), src.data(), base);
  187. if (!writeable)
  188. array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
  189. return a.release();
  190. }
  191. // Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that
  192. // reference the Eigen object's data with `base` as the python-registered base class (if omitted,
  193. // the base will be set to None, and lifetime management is up to the caller). The numpy array is
  194. // non-writeable if the given type is const.
  195. template <typename props, typename Type>
  196. handle eigen_ref_array(Type &src, handle parent = none()) {
  197. // none here is to get past array's should-we-copy detection, which currently always
  198. // copies when there is no base. Setting the base to None should be harmless.
  199. return eigen_array_cast<props>(src, parent, !std::is_const<Type>::value);
  200. }
  201. // Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy
  202. // array that references the encapsulated data with a python-side reference to the capsule to tie
  203. // its destruction to that of any dependent python objects. Const-ness is determined by whether or
  204. // not the Type of the pointer given is const.
  205. template <typename props, typename Type, typename = enable_if_t<is_eigen_dense_plain<Type>::value>>
  206. handle eigen_encapsulate(Type *src) {
  207. capsule base(src, [](void *o) { delete static_cast<Type *>(o); });
  208. return eigen_ref_array<props>(*src, base);
  209. }
// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense
// types.  Loading always copies the numpy data into the owned `value` matrix; casting back to
// python follows the return_value_policy.
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_dense_plain<Type>::value>> {
    using Scalar = typename Type::Scalar;
    using props = EigenProps<Type>;

    // Python -> C++: copy a conformable numpy array (1-D or 2-D) into `value`.
    bool load(handle src, bool) {
        // ensure() yields a null object (and we fail) if src can't be treated as an
        // array of Scalar at all.
        auto buf = array_t<Scalar>::ensure(src);
        if (!buf)
            return false;

        auto dims = buf.ndim();
        if (dims < 1 || dims > 2)
            return false;

        auto fits = props::conformable(buf);
        if (!fits)
            return false; // Non-comformable vector/matrix types

        // Assign through a fully-dynamically-strided const map, so any compatible
        // numpy layout can be copied in without an intermediate conversion.
        value = Eigen::Map<const Type, 0, EigenDStride>(buf.data(), fits.rows, fits.cols, fits.stride);

        return true;
    }

private:

    // Cast implementation, shared by all the public cast() overloads below.  `policy`
    // decides who owns the data backing the returned numpy array.
    template <typename CType>
    static handle cast_impl(CType *src, return_value_policy policy, handle parent) {
        switch (policy) {
            case return_value_policy::take_ownership:
            case return_value_policy::automatic:
                // numpy array owns *src via an encapsulating capsule:
                return eigen_encapsulate<props>(src);
            case return_value_policy::move:
                // Move into a fresh heap object, then encapsulate that:
                return eigen_encapsulate<props>(new CType(std::move(*src)));
            case return_value_policy::copy:
                // numpy makes its own copy of the data:
                return eigen_array_cast<props>(*src);
            case return_value_policy::reference:
            case return_value_policy::automatic_reference:
                // Array aliases *src; caller is responsible for keeping it alive:
                return eigen_ref_array<props>(*src);
            case return_value_policy::reference_internal:
                // Array aliases *src with `parent` registered as its base:
                return eigen_ref_array<props>(*src, parent);
            default:
                throw cast_error("unhandled return_value_policy: should not happen!");
        };
    }

public:

    // Normal returned non-reference, non-const value:
    static handle cast(Type &&src, return_value_policy /* policy */, handle parent) {
        return cast_impl(&src, return_value_policy::move, parent);
    }
    // If you return a non-reference const, we mark the numpy array readonly:
    static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) {
        return cast_impl(&src, return_value_policy::move, parent);
    }
    // lvalue reference return; default (automatic) becomes copy
    static handle cast(Type &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
            policy = return_value_policy::copy;
        return cast_impl(&src, policy, parent);
    }
    // const lvalue reference return; default (automatic) becomes copy
    static handle cast(const Type &src, return_value_policy policy, handle parent) {
        if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
            policy = return_value_policy::copy;
        // Dispatches to the const-pointer overload below:
        return cast(&src, policy, parent);
    }
    // non-const pointer return
    static handle cast(Type *src, return_value_policy policy, handle parent) {
        return cast_impl(src, policy, parent);
    }
    // const pointer return
    static handle cast(const Type *src, return_value_policy policy, handle parent) {
        return cast_impl(src, policy, parent);
    }

    static PYBIND11_DESCR name() { return type_descr(props::descriptor()); }

    operator Type*() { return &value; }
    operator Type&() { return value; }
    template <typename T> using cast_op_type = cast_op_type<T>;

private:
    // Storage for the loaded matrix; the conversion operators above hand out
    // pointers/references into this member.
    Type value;
};
// Eigen Ref/Map classes have slightly different policy requirements, meaning we don't want to force
// `move` when a Ref/Map rvalue is returned; we treat Ref<> sort of like a pointer (we care about
// the underlying data, not the outer shell).
template <typename Return>
struct return_value_policy_override<Return, enable_if_t<is_eigen_dense_map<Return>::value>> {
    // Pass the requested policy through untouched instead of overriding it.
    static return_value_policy policy(return_value_policy p) { return p; }
};
// Base class for casting reference/map/block/etc. objects back to python.
template <typename MapType> struct eigen_map_caster {
private:
    using props = EigenProps<MapType>;

public:
    // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has
    // to stay around), but we'll allow it under the assumption that you know what you're doing (and
    // have an appropriate keep_alive in place).  We return a numpy array pointing directly at the
    // ref's data (The numpy array ends up read-only if the ref was to a const matrix type.)  Note
    // that this means you need to ensure you don't destroy the object in some other way (e.g. with
    // an appropriate keep_alive, or with a reference to a statically allocated matrix).
    static handle cast(const MapType &src, return_value_policy policy, handle parent) {
        switch (policy) {
            case return_value_policy::copy:
                // numpy gets its own copy (no base handle supplied):
                return eigen_array_cast<props>(src);
            case return_value_policy::reference_internal:
                // Alias the map's data, with `parent` as the array's base object:
                return eigen_array_cast<props>(src, parent, is_eigen_mutable_map<MapType>::value);
            case return_value_policy::reference:
            case return_value_policy::automatic:
            case return_value_policy::automatic_reference:
                // Alias the map's data with a None base; lifetime is the caller's problem:
                return eigen_array_cast<props>(src, none(), is_eigen_mutable_map<MapType>::value);
            default:
                // move, take_ownership don't make any sense for a ref/map:
                pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type");
        }
    }

    static PYBIND11_DESCR name() { return props::descriptor(); }

    // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
    // types but not bound arguments).  We still provide them (with an explicit delete) so that
    // you end up here if you try anyway.
    bool load(handle, bool) = delete;
    operator MapType() = delete;
    template <typename> using cast_op_type = MapType;
};
// We can return any map-like object (but can only load Refs, specialized next):
// eigen_map_caster provides cast()/name() and explicitly deletes load().
template <typename Type> struct type_caster<Type, enable_if_t<is_eigen_dense_map<Type>::value>>
    : eigen_map_caster<Type> {};
// Loader for Ref<...> arguments.  See the documentation for info on how to make this work without
// copying (it requires some extra effort in many cases).
template <typename PlainObjectType, typename StrideType>
struct type_caster<
    Eigen::Ref<PlainObjectType, 0, StrideType>,
    enable_if_t<is_eigen_dense_map<Eigen::Ref<PlainObjectType, 0, StrideType>>::value>
> : public eigen_map_caster<Eigen::Ref<PlainObjectType, 0, StrideType>> {
private:
    using Type = Eigen::Ref<PlainObjectType, 0, StrideType>;
    using props = EigenProps<Type>;
    using Scalar = typename props::Scalar;
    using MapType = Eigen::Map<PlainObjectType, 0, StrideType>;
    // The numpy array type we accept: forced to Scalar dtype, and required to be
    // C- or F-contiguous whenever the Ref's compile-time strides mandate a
    // row-major (inner stride 1) or column-major layout, respectively.
    using Array = array_t<Scalar, array::forcecast |
                ((props::row_major ? props::inner_stride : props::outer_stride) == 1 ? array::c_style :
                 (props::row_major ? props::outer_stride : props::inner_stride) == 1 ? array::f_style : 0)>;
    // A Ref<M> with non-const M is a read-write view: loading it requires a writeable
    // numpy array and forbids the copying fallback below.
    static constexpr bool need_writeable = is_eigen_mutable_map<Type>::value;
    // Delay construction (these have no default constructor)
    std::unique_ptr<MapType> map;
    std::unique_ptr<Type> ref;
    // Our array.  When possible, this is just a numpy array pointing to the source data, but
    // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible
    // layout, or is an array of a type that needs to be converted).  Using a numpy temporary
    // (rather than an Eigen temporary) saves an extra copy when we need both type conversion and
    // storage order conversion.  (Note that we refuse to use this temporary copy when loading an
    // argument for a Ref<M> with M non-const, i.e. a read-write reference).
    Array copy_or_ref;
public:
    bool load(handle src, bool convert) {
        // First check whether what we have is already an array of the right type.  If not, we can't
        // avoid a copy (because the copy is also going to do type conversion).
        bool need_copy = !isinstance<Array>(src);

        EigenConformable<props::row_major> fits;
        if (!need_copy) {
            // We don't need a converting copy, but we also need to check whether the strides are
            // compatible with the Ref's stride requirements
            Array aref = reinterpret_borrow<Array>(src);

            if (aref && (!need_writeable || aref.writeable())) {
                fits = props::conformable(aref);
                if (!fits) return false; // Incompatible dimensions
                if (!fits.template stride_compatible<props>())
                    need_copy = true;
                else
                    copy_or_ref = std::move(aref);
            }
            else {
                // Not a usable array (or not writeable when we need it to be): fall
                // through to the copying path below.
                need_copy = true;
            }
        }

        if (need_copy) {
            // We need to copy: If we need a mutable reference, or we're not supposed to convert
            // (either because we're in the no-convert overload pass, or because we're explicitly
            // instructed not to copy (via `py::arg().noconvert()`) we have to fail loading.
            if (!convert || need_writeable) return false;

            Array copy = Array::ensure(src);
            if (!copy) return false;
            fits = props::conformable(copy);
            if (!fits || !fits.template stride_compatible<props>())
                return false;
            copy_or_ref = std::move(copy);
        }

        // Rebuild the map + ref over whatever copy_or_ref now points at.  Reset `ref`
        // first since it refers to the previous `map`.
        ref.reset();
        map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner())));
        ref.reset(new Type(*map));

        return true;
    }

    operator Type*() { return ref.get(); }
    operator Type&() { return *ref; }
    template <typename _T> using cast_op_type = pybind11::detail::cast_op_type<_T>;

private:
    // data() overloads: hand out a mutable pointer only for mutable (writeable) maps.
    template <typename T = Type, enable_if_t<is_eigen_mutable_map<T>::value, int> = 0>
    Scalar *data(Array &a) { return a.mutable_data(); }

    template <typename T = Type, enable_if_t<!is_eigen_mutable_map<T>::value, int> = 0>
    const Scalar *data(Array &a) { return a.data(); }

    // Attempt to figure out a constructor of `Stride` that will work.
    // If both strides are fixed, use a default constructor:
    template <typename S> using stride_ctor_default = bool_constant<
        S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
        std::is_default_constructible<S>::value>;
    // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like
    // Eigen::Stride, and use it:
    template <typename S> using stride_ctor_dual = bool_constant<
        !stride_ctor_default<S>::value && std::is_constructible<S, EigenIndex, EigenIndex>::value>;
    // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use
    // it (passing whichever stride is dynamic).
    template <typename S> using stride_ctor_outer = bool_constant<
        !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
        S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic &&
        std::is_constructible<S, EigenIndex>::value>;
    template <typename S> using stride_ctor_inner = bool_constant<
        !any_of<stride_ctor_default<S>, stride_ctor_dual<S>>::value &&
        S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic &&
        std::is_constructible<S, EigenIndex>::value>;

    // make_stride() overloads, one per viable constructor shape detected above:
    template <typename S = StrideType, enable_if_t<stride_ctor_default<S>::value, int> = 0>
    static S make_stride(EigenIndex, EigenIndex) { return S(); }
    template <typename S = StrideType, enable_if_t<stride_ctor_dual<S>::value, int> = 0>
    static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); }
    template <typename S = StrideType, enable_if_t<stride_ctor_outer<S>::value, int> = 0>
    static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); }
    template <typename S = StrideType, enable_if_t<stride_ctor_inner<S>::value, int> = 0>
    static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); }
};
  431. // type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not
  432. // EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout).
  433. // load() is not supported, but we can cast them into the python domain by first copying to a
  434. // regular Eigen::Matrix, then casting that.
  435. template <typename Type>
  436. struct type_caster<Type, enable_if_t<is_eigen_other<Type>::value>> {
  437. protected:
  438. using Matrix = Eigen::Matrix<typename Type::Scalar, Type::RowsAtCompileTime, Type::ColsAtCompileTime>;
  439. using props = EigenProps<Matrix>;
  440. public:
  441. static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
  442. handle h = eigen_encapsulate<props>(new Matrix(src));
  443. return h;
  444. }
  445. static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); }
  446. static PYBIND11_DESCR name() { return props::descriptor(); }
  447. // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return
  448. // types but not bound arguments). We still provide them (with an explicitly delete) so that
  449. // you end up here if you try anyway.
  450. bool load(handle, bool) = delete;
  451. operator Type() = delete;
  452. template <typename> using cast_op_type = Type;
  453. };
// Type caster for sparse Eigen matrices, converting to/from scipy.sparse matrices
// (csr_matrix for row-major storage, csc_matrix for column-major).
template<typename Type>
struct type_caster<Type, enable_if_t<is_eigen_sparse<Type>::value>> {
    typedef typename Type::Scalar Scalar;
    // Index type of the compressed-storage arrays, deduced from outerIndexPtr() so it
    // matches whatever integer type Eigen actually stores.
    typedef typename std::remove_reference<decltype(*std::declval<Type>().outerIndexPtr())>::type StorageIndex;
    typedef typename Type::Index Index;
    static constexpr bool rowMajor = Type::IsRowMajor;

    // Python -> C++: accept a scipy.sparse matrix, converting it to the matching
    // compressed format (csr/csc) via scipy if it isn't in that format already.
    bool load(handle src, bool) {
        if (!src)
            return false;

        auto obj = reinterpret_borrow<object>(src);
        object sparse_module = module::import("scipy.sparse");
        object matrix_type = sparse_module.attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        if (obj.get_type() != matrix_type.ptr()) {
            try {
                // Let scipy convert (e.g. coo/lil/dense) into the format we need:
                obj = matrix_type(obj);
            } catch (const error_already_set &) {
                return false;
            }
        }

        // The three arrays of the compressed representation (values, inner indices,
        // outer index pointers), plus the matrix shape and nonzero count:
        auto values = array_t<Scalar>((object) obj.attr("data"));
        auto innerIndices = array_t<StorageIndex>((object) obj.attr("indices"));
        auto outerIndices = array_t<StorageIndex>((object) obj.attr("indptr"));
        auto shape = pybind11::tuple((pybind11::object) obj.attr("shape"));
        auto nnz = obj.attr("nnz").cast<Index>();

        if (!values || !innerIndices || !outerIndices)
            return false;

        // Assign from a temporary mapped view over the scipy buffers; the assignment
        // copies into the owned `value` matrix.
        value = Eigen::MappedSparseMatrix<Scalar, Type::Flags, StorageIndex>(
            shape[0].cast<Index>(), shape[1].cast<Index>(), nnz,
            outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data());

        return true;
    }

    // C++ -> Python: build a scipy.sparse csr/csc matrix from the compressed arrays.
    // (The `array` constructions below have no base handle, so numpy makes its own
    // copies of the data/indices/indptr buffers.)
    static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) {
        // scipy's (data, indices, indptr) constructor expects compressed storage:
        const_cast<Type&>(src).makeCompressed();

        object matrix_type = module::import("scipy.sparse").attr(
            rowMajor ? "csr_matrix" : "csc_matrix");

        array data((size_t) src.nonZeros(), src.valuePtr());
        // One outer index pointer per outer dimension, plus the trailing end marker:
        array outerIndices((size_t) (rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr());
        array innerIndices((size_t) src.nonZeros(), src.innerIndexPtr());

        return matrix_type(
            std::make_tuple(data, innerIndices, outerIndices),
            std::make_pair(src.rows(), src.cols())
        ).release();
    }

    PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[")
            + npy_format_descriptor<Scalar>::name() + _("]"));
};
  501. NAMESPACE_END(detail)
  502. NAMESPACE_END(pybind11)
  503. #if defined(__GNUG__) || defined(__clang__)
  504. # pragma GCC diagnostic pop
  505. #elif defined(_MSC_VER)
  506. # pragma warning(pop)
  507. #endif