You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1393 lines
54 KiB

8 years ago
  1. /*
  2. pybind11/numpy.h: Basic NumPy support, vectorize() wrapper
  3. Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
  4. All rights reserved. Use of this source code is governed by a
  5. BSD-style license that can be found in the LICENSE file.
  6. */
  7. #pragma once
  8. #include "pybind11.h"
  9. #include "complex.h"
  10. #include <numeric>
  11. #include <algorithm>
  12. #include <array>
  13. #include <cstdlib>
  14. #include <cstring>
  15. #include <sstream>
  16. #include <string>
  17. #include <initializer_list>
  18. #include <functional>
  19. #include <utility>
  20. #include <typeindex>
// MSVC: suppress C4127 ("conditional expression is constant"), which template code in this
// header triggers with compile-time-constant conditions (e.g. `Dims >= 0 && ...` in
// array::unchecked()). NOTE(review): the matching `warning(pop)` is not in this chunk --
// presumably at the end of the file; confirm.
#if defined(_MSC_VER)
# pragma warning(push)
# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
#endif
/* This will be true on all flat address space platforms and allows us to reduce the
whole npy_intp / size_t / Py_intptr_t business down to just size_t for all size
and dimension types (e.g. shape, strides, indexing), instead of inflicting this
upon the library user. The array code below relies on it when it reinterprets NumPy's
dimension/stride pointers as `size_t *` (and `size_t *` as `Py_intptr_t *`). */
static_assert(sizeof(size_t) == sizeof(Py_intptr_t), "size_t != Py_intptr_t");
NAMESPACE_BEGIN(pybind11)
// Forward declaration: the unchecked_reference proxies below befriend `array`, which is
// defined further down in this file.
class array;
NAMESPACE_BEGIN(detail)
// Maps a C++ type to its NumPy dtype (`::dtype()`) and buffer format (`::format()`); used by
// dtype::of<T>() and format_descriptor<T>. Specializations are not visible in this chunk.
// The defaulted SFINAE parameter lets specializations be constrained with enable_if_t.
template <typename type, typename SFINAE = void> struct npy_format_descriptor;
/* Layout proxy for NumPy dtype (descriptor) objects. array_descriptor_proxy() reinterprets a
   PyObject* as this struct, and dtype::itemsize()/kind()/has_fields() and array::itemsize()
   read `elsize`, `kind` and `names` through it, so member order and types must match the C
   struct of the NumPy build the module runs against.
   NOTE(review): presumably mirrors PyArray_Descr from the NumPy 1.x headers; NumPy 2.x changed
   that struct's layout (e.g. the width/position of `elsize`) -- confirm for supported versions. */
struct PyArrayDescr_Proxy {
    PyObject_HEAD
    PyObject *typeobj;
    char kind;      // single-character type code (dtype::kind(); 'V' is checked in strip_padding)
    char type;
    char byteorder;
    char flags;
    int type_num;
    int elsize;     // element size in bytes (dtype::itemsize(), array::itemsize())
    int alignment;
    char *subarray;
    PyObject *fields;
    PyObject *names; // non-null for structured dtypes (dtype::has_fields())
};
/* Layout proxy for ndarray objects: array_proxy() reinterprets an array's PyObject* as this
   struct so that `array` can read its fields directly instead of going through the C API.
   NOTE(review): member order/types presumably mirror NumPy's PyArrayObject and must match the
   NumPy ABI in use -- confirm before changing. */
struct PyArray_Proxy {
    PyObject_HEAD
    char *data;          // start of the element buffer; array::data() adds byte offsets to it
    int nd;              // number of dimensions (array::ndim())
    ssize_t *dimensions; // per-axis extents; read as size_t by array::shape()
    ssize_t *strides;    // per-axis byte strides; read as size_t by array::strides()
    PyObject *base;      // base object (array::base())
    PyObject *descr;     // dtype descriptor (array::dtype(), array::itemsize())
    int flags;           // NPY_ARRAY_* flag bits (array::flags(), check_flags())
};
/* Layout proxy presumably for NumPy void (structured) scalar objects. It is not referenced in
   this part of the file. NOTE(review): confirm its users and that the layout matches NumPy's
   PyVoidScalarObject for the supported NumPy versions. */
struct PyVoidScalarObject_Proxy {
    PyObject_VAR_HEAD
    char *obval;
    PyArrayDescr_Proxy *descr;
    int flags;
    PyObject *base;
};
/* Registry entry stored in numpy_internals::registered_dtypes for one C++ type. */
struct numpy_type_info {
    // Presumably the registered dtype object, held as a raw pointer (its reference
    // ownership is not managed here -- confirm at the registration site).
    PyObject* dtype_ptr;
    // Presumably the buffer-protocol format string for the type -- confirm at registration.
    std::string format_str;
};
  69. struct numpy_internals {
  70. std::unordered_map<std::type_index, numpy_type_info> registered_dtypes;
  71. numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) {
  72. auto it = registered_dtypes.find(std::type_index(tinfo));
  73. if (it != registered_dtypes.end())
  74. return &(it->second);
  75. if (throw_if_missing)
  76. pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name());
  77. return nullptr;
  78. }
  79. template<typename T> numpy_type_info *get_type_info(bool throw_if_missing = true) {
  80. return get_type_info(typeid(typename std::remove_cv<T>::type), throw_if_missing);
  81. }
  82. };
  83. inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) {
  84. ptr = &get_or_create_shared_data<numpy_internals>("_numpy_internals");
  85. }
  86. inline numpy_internals& get_numpy_internals() {
  87. static numpy_internals* ptr = nullptr;
  88. if (!ptr)
  89. load_numpy_internals(ptr);
  90. return *ptr;
  91. }
  92. struct npy_api {
  93. enum constants {
  94. NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
  95. NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
  96. NPY_ARRAY_OWNDATA_ = 0x0004,
  97. NPY_ARRAY_FORCECAST_ = 0x0010,
  98. NPY_ARRAY_ENSUREARRAY_ = 0x0040,
  99. NPY_ARRAY_ALIGNED_ = 0x0100,
  100. NPY_ARRAY_WRITEABLE_ = 0x0400,
  101. NPY_BOOL_ = 0,
  102. NPY_BYTE_, NPY_UBYTE_,
  103. NPY_SHORT_, NPY_USHORT_,
  104. NPY_INT_, NPY_UINT_,
  105. NPY_LONG_, NPY_ULONG_,
  106. NPY_LONGLONG_, NPY_ULONGLONG_,
  107. NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_,
  108. NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_,
  109. NPY_OBJECT_ = 17,
  110. NPY_STRING_, NPY_UNICODE_, NPY_VOID_
  111. };
  112. static npy_api& get() {
  113. static npy_api api = lookup();
  114. return api;
  115. }
  116. bool PyArray_Check_(PyObject *obj) const {
  117. return (bool) PyObject_TypeCheck(obj, PyArray_Type_);
  118. }
  119. bool PyArrayDescr_Check_(PyObject *obj) const {
  120. return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_);
  121. }
  122. PyObject *(*PyArray_DescrFromType_)(int);
  123. PyObject *(*PyArray_NewFromDescr_)
  124. (PyTypeObject *, PyObject *, int, Py_intptr_t *,
  125. Py_intptr_t *, void *, int, PyObject *);
  126. PyObject *(*PyArray_DescrNewFromType_)(int);
  127. PyObject *(*PyArray_NewCopy_)(PyObject *, int);
  128. PyTypeObject *PyArray_Type_;
  129. PyTypeObject *PyVoidArrType_Type_;
  130. PyTypeObject *PyArrayDescr_Type_;
  131. PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
  132. PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *);
  133. int (*PyArray_DescrConverter_) (PyObject *, PyObject **);
  134. bool (*PyArray_EquivTypes_) (PyObject *, PyObject *);
  135. int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *,
  136. Py_ssize_t *, PyObject **, PyObject *);
  137. PyObject *(*PyArray_Squeeze_)(PyObject *);
  138. int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
  139. private:
  140. enum functions {
  141. API_PyArray_Type = 2,
  142. API_PyArrayDescr_Type = 3,
  143. API_PyVoidArrType_Type = 39,
  144. API_PyArray_DescrFromType = 45,
  145. API_PyArray_DescrFromScalar = 57,
  146. API_PyArray_FromAny = 69,
  147. API_PyArray_NewCopy = 85,
  148. API_PyArray_NewFromDescr = 94,
  149. API_PyArray_DescrNewFromType = 9,
  150. API_PyArray_DescrConverter = 174,
  151. API_PyArray_EquivTypes = 182,
  152. API_PyArray_GetArrayParamsFromObject = 278,
  153. API_PyArray_Squeeze = 136,
  154. API_PyArray_SetBaseObject = 282
  155. };
  156. static npy_api lookup() {
  157. module m = module::import("numpy.core.multiarray");
  158. auto c = m.attr("_ARRAY_API");
  159. #if PY_MAJOR_VERSION >= 3
  160. void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL);
  161. #else
  162. void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr());
  163. #endif
  164. npy_api api;
  165. #define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
  166. DECL_NPY_API(PyArray_Type);
  167. DECL_NPY_API(PyVoidArrType_Type);
  168. DECL_NPY_API(PyArrayDescr_Type);
  169. DECL_NPY_API(PyArray_DescrFromType);
  170. DECL_NPY_API(PyArray_DescrFromScalar);
  171. DECL_NPY_API(PyArray_FromAny);
  172. DECL_NPY_API(PyArray_NewCopy);
  173. DECL_NPY_API(PyArray_NewFromDescr);
  174. DECL_NPY_API(PyArray_DescrNewFromType);
  175. DECL_NPY_API(PyArray_DescrConverter);
  176. DECL_NPY_API(PyArray_EquivTypes);
  177. DECL_NPY_API(PyArray_GetArrayParamsFromObject);
  178. DECL_NPY_API(PyArray_Squeeze);
  179. DECL_NPY_API(PyArray_SetBaseObject);
  180. #undef DECL_NPY_API
  181. return api;
  182. }
  183. };
  184. inline PyArray_Proxy* array_proxy(void* ptr) {
  185. return reinterpret_cast<PyArray_Proxy*>(ptr);
  186. }
  187. inline const PyArray_Proxy* array_proxy(const void* ptr) {
  188. return reinterpret_cast<const PyArray_Proxy*>(ptr);
  189. }
  190. inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) {
  191. return reinterpret_cast<PyArrayDescr_Proxy*>(ptr);
  192. }
  193. inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) {
  194. return reinterpret_cast<const PyArrayDescr_Proxy*>(ptr);
  195. }
  196. inline bool check_flags(const void* ptr, int flag) {
  197. return (flag == (array_proxy(ptr)->flags & flag));
  198. }
  199. template <typename T> struct is_std_array : std::false_type { };
  200. template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type { };
  201. template <typename T> struct is_complex : std::false_type { };
  202. template <typename T> struct is_complex<std::complex<T>> : std::true_type { };
  203. template <typename T> using is_pod_struct = all_of<
  204. std::is_pod<T>, // since we're accessing directly in memory we need a POD type
  205. satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
  206. >;
  207. template <size_t Dim = 0, typename Strides> size_t byte_offset_unsafe(const Strides &) { return 0; }
  208. template <size_t Dim = 0, typename Strides, typename... Ix>
  209. size_t byte_offset_unsafe(const Strides &strides, size_t i, Ix... index) {
  210. return i * strides[Dim] + byte_offset_unsafe<Dim + 1>(strides, index...);
  211. }
  212. /** Proxy class providing unsafe, unchecked const access to array data. This is constructed through
  213. * the `unchecked<T, N>()` method of `array` or the `unchecked<N>()` method of `array_t<T>`. `Dims`
  214. * will be -1 for dimensions determined at runtime.
  215. */
  216. template <typename T, ssize_t Dims>
  217. class unchecked_reference {
  218. protected:
  219. static constexpr bool Dynamic = Dims < 0;
  220. const unsigned char *data_;
  221. // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to
  222. // make large performance gains on big, nested loops, but requires compile-time dimensions
  223. conditional_t<Dynamic, const size_t *, std::array<size_t, (size_t) Dims>>
  224. shape_, strides_;
  225. const size_t dims_;
  226. friend class pybind11::array;
  227. // Constructor for compile-time dimensions:
  228. template <bool Dyn = Dynamic>
  229. unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<!Dyn, size_t>)
  230. : data_{reinterpret_cast<const unsigned char *>(data)}, dims_{Dims} {
  231. for (size_t i = 0; i < dims_; i++) {
  232. shape_[i] = shape[i];
  233. strides_[i] = strides[i];
  234. }
  235. }
  236. // Constructor for runtime dimensions:
  237. template <bool Dyn = Dynamic>
  238. unchecked_reference(const void *data, const size_t *shape, const size_t *strides, enable_if_t<Dyn, size_t> dims)
  239. : data_{reinterpret_cast<const unsigned char *>(data)}, shape_{shape}, strides_{strides}, dims_{dims} {}
  240. public:
  241. /** Unchecked const reference access to data at the given indices. For a compile-time known
  242. * number of dimensions, this requires the correct number of arguments; for run-time
  243. * dimensionality, this is not checked (and so is up to the caller to use safely).
  244. */
  245. template <typename... Ix> const T &operator()(Ix... index) const {
  246. static_assert(sizeof...(Ix) == Dims || Dynamic,
  247. "Invalid number of indices for unchecked array reference");
  248. return *reinterpret_cast<const T *>(data_ + byte_offset_unsafe(strides_, size_t(index)...));
  249. }
  250. /** Unchecked const reference access to data; this operator only participates if the reference
  251. * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`.
  252. */
  253. template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
  254. const T &operator[](size_t index) const { return operator()(index); }
  255. /// Pointer access to the data at the given indices.
  256. template <typename... Ix> const T *data(Ix... ix) const { return &operator()(size_t(ix)...); }
  257. /// Returns the item size, i.e. sizeof(T)
  258. constexpr static size_t itemsize() { return sizeof(T); }
  259. /// Returns the shape (i.e. size) of dimension `dim`
  260. size_t shape(size_t dim) const { return shape_[dim]; }
  261. /// Returns the number of dimensions of the array
  262. size_t ndim() const { return dims_; }
  263. /// Returns the total number of elements in the referenced array, i.e. the product of the shapes
  264. template <bool Dyn = Dynamic>
  265. enable_if_t<!Dyn, size_t> size() const {
  266. return std::accumulate(shape_.begin(), shape_.end(), (size_t) 1, std::multiplies<size_t>());
  267. }
  268. template <bool Dyn = Dynamic>
  269. enable_if_t<Dyn, size_t> size() const {
  270. return std::accumulate(shape_, shape_ + ndim(), (size_t) 1, std::multiplies<size_t>());
  271. }
  272. /// Returns the total number of bytes used by the referenced data. Note that the actual span in
  273. /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice).
  274. size_t nbytes() const {
  275. return size() * itemsize();
  276. }
  277. };
  278. template <typename T, ssize_t Dims>
  279. class unchecked_mutable_reference : public unchecked_reference<T, Dims> {
  280. friend class pybind11::array;
  281. using ConstBase = unchecked_reference<T, Dims>;
  282. using ConstBase::ConstBase;
  283. using ConstBase::Dynamic;
  284. public:
  285. /// Mutable, unchecked access to data at the given indices.
  286. template <typename... Ix> T& operator()(Ix... index) {
  287. static_assert(sizeof...(Ix) == Dims || Dynamic,
  288. "Invalid number of indices for unchecked array reference");
  289. return const_cast<T &>(ConstBase::operator()(index...));
  290. }
  291. /** Mutable, unchecked access data at the given index; this operator only participates if the
  292. * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is
  293. * exactly equivalent to `obj(index)`.
  294. */
  295. template <size_t D = Dims, typename = enable_if_t<D == 1 || Dynamic>>
  296. T &operator[](size_t index) { return operator()(index); }
  297. /// Mutable pointer access to the data at the given indices.
  298. template <typename... Ix> T *mutable_data(Ix... ix) { return &operator()(size_t(ix)...); }
  299. };
  300. template <typename T, size_t Dim>
  301. struct type_caster<unchecked_reference<T, Dim>> {
  302. static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable");
  303. };
  304. template <typename T, size_t Dim>
  305. struct type_caster<unchecked_mutable_reference<T, Dim>> : type_caster<unchecked_reference<T, Dim>> {};
  306. NAMESPACE_END(detail)
  307. class dtype : public object {
  308. public:
  309. PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_);
  310. explicit dtype(const buffer_info &info) {
  311. dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format)));
  312. // If info.itemsize == 0, use the value calculated from the format string
  313. m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr();
  314. }
  315. explicit dtype(const std::string &format) {
  316. m_ptr = from_args(pybind11::str(format)).release().ptr();
  317. }
  318. dtype(const char *format) : dtype(std::string(format)) { }
  319. dtype(list names, list formats, list offsets, size_t itemsize) {
  320. dict args;
  321. args["names"] = names;
  322. args["formats"] = formats;
  323. args["offsets"] = offsets;
  324. args["itemsize"] = pybind11::int_(itemsize);
  325. m_ptr = from_args(args).release().ptr();
  326. }
  327. /// This is essentially the same as calling numpy.dtype(args) in Python.
  328. static dtype from_args(object args) {
  329. PyObject *ptr = nullptr;
  330. if (!detail::npy_api::get().PyArray_DescrConverter_(args.release().ptr(), &ptr) || !ptr)
  331. throw error_already_set();
  332. return reinterpret_steal<dtype>(ptr);
  333. }
  334. /// Return dtype associated with a C++ type.
  335. template <typename T> static dtype of() {
  336. return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
  337. }
  338. /// Size of the data type in bytes.
  339. size_t itemsize() const {
  340. return (size_t) detail::array_descriptor_proxy(m_ptr)->elsize;
  341. }
  342. /// Returns true for structured data types.
  343. bool has_fields() const {
  344. return detail::array_descriptor_proxy(m_ptr)->names != nullptr;
  345. }
  346. /// Single-character type code.
  347. char kind() const {
  348. return detail::array_descriptor_proxy(m_ptr)->kind;
  349. }
  350. private:
  351. static object _dtype_from_pep3118() {
  352. static PyObject *obj = module::import("numpy.core._internal")
  353. .attr("_dtype_from_pep3118").cast<object>().release().ptr();
  354. return reinterpret_borrow<object>(obj);
  355. }
  356. dtype strip_padding(size_t itemsize) {
  357. // Recursively strip all void fields with empty names that are generated for
  358. // padding fields (as of NumPy v1.11).
  359. if (!has_fields())
  360. return *this;
  361. struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; };
  362. std::vector<field_descr> field_descriptors;
  363. for (auto field : attr("fields").attr("items")()) {
  364. auto spec = field.cast<tuple>();
  365. auto name = spec[0].cast<pybind11::str>();
  366. auto format = spec[1].cast<tuple>()[0].cast<dtype>();
  367. auto offset = spec[1].cast<tuple>()[1].cast<pybind11::int_>();
  368. if (!len(name) && format.kind() == 'V')
  369. continue;
  370. field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset});
  371. }
  372. std::sort(field_descriptors.begin(), field_descriptors.end(),
  373. [](const field_descr& a, const field_descr& b) {
  374. return a.offset.cast<int>() < b.offset.cast<int>();
  375. });
  376. list names, formats, offsets;
  377. for (auto& descr : field_descriptors) {
  378. names.append(descr.name);
  379. formats.append(descr.format);
  380. offsets.append(descr.offset);
  381. }
  382. return dtype(names, formats, offsets, itemsize);
  383. }
  384. };
  385. class array : public buffer {
  386. public:
  387. PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array)
  388. enum {
  389. c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_,
  390. f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_,
  391. forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_
  392. };
  393. array() : array(0, static_cast<const double *>(nullptr)) {}
  394. array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
  395. const std::vector<size_t> &strides, const void *ptr = nullptr,
  396. handle base = handle()) {
  397. auto& api = detail::npy_api::get();
  398. auto ndim = shape.size();
  399. if (shape.size() != strides.size())
  400. pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
  401. auto descr = dt;
  402. int flags = 0;
  403. if (base && ptr) {
  404. if (isinstance<array>(base))
  405. /* Copy flags from base (except ownership bit) */
  406. flags = reinterpret_borrow<array>(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_;
  407. else
  408. /* Writable by default, easy to downgrade later on if needed */
  409. flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
  410. }
  411. auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
  412. api.PyArray_Type_, descr.release().ptr(), (int) ndim,
  413. reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
  414. reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
  415. const_cast<void *>(ptr), flags, nullptr));
  416. if (!tmp)
  417. pybind11_fail("NumPy: unable to create array!");
  418. if (ptr) {
  419. if (base) {
  420. api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr());
  421. } else {
  422. tmp = reinterpret_steal<object>(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */));
  423. }
  424. }
  425. m_ptr = tmp.release().ptr();
  426. }
  427. array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
  428. const void *ptr = nullptr, handle base = handle())
  429. : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
  430. array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
  431. handle base = handle())
  432. : array(dt, std::vector<size_t>{ count }, ptr, base) { }
  433. template<typename T> array(const std::vector<size_t>& shape,
  434. const std::vector<size_t>& strides,
  435. const T* ptr, handle base = handle())
  436. : array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
  437. template <typename T>
  438. array(const std::vector<size_t> &shape, const T *ptr,
  439. handle base = handle())
  440. : array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
  441. template <typename T>
  442. array(size_t count, const T *ptr, handle base = handle())
  443. : array(std::vector<size_t>{ count }, ptr, base) { }
  444. explicit array(const buffer_info &info)
  445. : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
  446. /// Array descriptor (dtype)
  447. pybind11::dtype dtype() const {
  448. return reinterpret_borrow<pybind11::dtype>(detail::array_proxy(m_ptr)->descr);
  449. }
  450. /// Total number of elements
  451. size_t size() const {
  452. return std::accumulate(shape(), shape() + ndim(), (size_t) 1, std::multiplies<size_t>());
  453. }
  454. /// Byte size of a single element
  455. size_t itemsize() const {
  456. return (size_t) detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
  457. }
  458. /// Total number of bytes
  459. size_t nbytes() const {
  460. return size() * itemsize();
  461. }
  462. /// Number of dimensions
  463. size_t ndim() const {
  464. return (size_t) detail::array_proxy(m_ptr)->nd;
  465. }
  466. /// Base object
  467. object base() const {
  468. return reinterpret_borrow<object>(detail::array_proxy(m_ptr)->base);
  469. }
  470. /// Dimensions of the array
  471. const size_t* shape() const {
  472. return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->dimensions);
  473. }
  474. /// Dimension along a given axis
  475. size_t shape(size_t dim) const {
  476. if (dim >= ndim())
  477. fail_dim_check(dim, "invalid axis");
  478. return shape()[dim];
  479. }
  480. /// Strides of the array
  481. const size_t* strides() const {
  482. return reinterpret_cast<const size_t *>(detail::array_proxy(m_ptr)->strides);
  483. }
  484. /// Stride along a given axis
  485. size_t strides(size_t dim) const {
  486. if (dim >= ndim())
  487. fail_dim_check(dim, "invalid axis");
  488. return strides()[dim];
  489. }
  490. /// Return the NumPy array flags
  491. int flags() const {
  492. return detail::array_proxy(m_ptr)->flags;
  493. }
  494. /// If set, the array is writeable (otherwise the buffer is read-only)
  495. bool writeable() const {
  496. return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_);
  497. }
  498. /// If set, the array owns the data (will be freed when the array is deleted)
  499. bool owndata() const {
  500. return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_);
  501. }
  502. /// Pointer to the contained data. If index is not provided, points to the
  503. /// beginning of the buffer. May throw if the index would lead to out of bounds access.
  504. template<typename... Ix> const void* data(Ix... index) const {
  505. return static_cast<const void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
  506. }
  507. /// Mutable pointer to the contained data. If index is not provided, points to the
  508. /// beginning of the buffer. May throw if the index would lead to out of bounds access.
  509. /// May throw if the array is not writeable.
  510. template<typename... Ix> void* mutable_data(Ix... index) {
  511. check_writeable();
  512. return static_cast<void *>(detail::array_proxy(m_ptr)->data + offset_at(index...));
  513. }
  514. /// Byte offset from beginning of the array to a given index (full or partial).
  515. /// May throw if the index would lead to out of bounds access.
  516. template<typename... Ix> size_t offset_at(Ix... index) const {
  517. if (sizeof...(index) > ndim())
  518. fail_dim_check(sizeof...(index), "too many indices for an array");
  519. return byte_offset(size_t(index)...);
  520. }
  521. size_t offset_at() const { return 0; }
  522. /// Item count from beginning of the array to a given index (full or partial).
  523. /// May throw if the index would lead to out of bounds access.
  524. template<typename... Ix> size_t index_at(Ix... index) const {
  525. return offset_at(index...) / itemsize();
  526. }
  527. /** Returns a proxy object that provides access to the array's data without bounds or
  528. * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
  529. * care: the array must not be destroyed or reshaped for the duration of the returned object,
  530. * and the caller must take care not to access invalid dimensions or dimension indices.
  531. */
  532. template <typename T, ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
  533. if (Dims >= 0 && ndim() != (size_t) Dims)
  534. throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
  535. "; expected " + std::to_string(Dims));
  536. return detail::unchecked_mutable_reference<T, Dims>(mutable_data(), shape(), strides(), ndim());
  537. }
  538. /** Returns a proxy object that provides const access to the array's data without bounds or
  539. * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the
  540. * underlying array have the `writable` flag. Use with care: the array must not be destroyed or
  541. * reshaped for the duration of the returned object, and the caller must take care not to access
  542. * invalid dimensions or dimension indices.
  543. */
  544. template <typename T, ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
  545. if (Dims >= 0 && ndim() != (size_t) Dims)
  546. throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) +
  547. "; expected " + std::to_string(Dims));
  548. return detail::unchecked_reference<T, Dims>(data(), shape(), strides(), ndim());
  549. }
  550. /// Return a new view with all of the dimensions of length 1 removed
  551. array squeeze() {
  552. auto& api = detail::npy_api::get();
  553. return reinterpret_steal<array>(api.PyArray_Squeeze_(m_ptr));
  554. }
  555. /// Ensure that the argument is a NumPy array
  556. /// In case of an error, nullptr is returned and the Python error is cleared.
  557. static array ensure(handle h, int ExtraFlags = 0) {
  558. auto result = reinterpret_steal<array>(raw_array(h.ptr(), ExtraFlags));
  559. if (!result)
  560. PyErr_Clear();
  561. return result;
  562. }
  563. protected:
  564. template<typename, typename> friend struct detail::npy_format_descriptor;
  565. void fail_dim_check(size_t dim, const std::string& msg) const {
  566. throw index_error(msg + ": " + std::to_string(dim) +
  567. " (ndim = " + std::to_string(ndim()) + ")");
  568. }
  569. template<typename... Ix> size_t byte_offset(Ix... index) const {
  570. check_dimensions(index...);
  571. return detail::byte_offset_unsafe(strides(), size_t(index)...);
  572. }
  573. void check_writeable() const {
  574. if (!writeable())
  575. throw std::domain_error("array is not writeable");
  576. }
  577. static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
  578. auto ndim = shape.size();
  579. std::vector<size_t> strides(ndim);
  580. if (ndim) {
  581. std::fill(strides.begin(), strides.end(), itemsize);
  582. for (size_t i = 0; i < ndim - 1; i++)
  583. for (size_t j = 0; j < ndim - 1 - i; j++)
  584. strides[j] *= shape[ndim - 1 - i];
  585. }
  586. return strides;
  587. }
  588. template<typename... Ix> void check_dimensions(Ix... index) const {
  589. check_dimensions_impl(size_t(0), shape(), size_t(index)...);
  590. }
  591. void check_dimensions_impl(size_t, const size_t*) const { }
  592. template<typename... Ix> void check_dimensions_impl(size_t axis, const size_t* shape, size_t i, Ix... index) const {
  593. if (i >= *shape) {
  594. throw index_error(std::string("index ") + std::to_string(i) +
  595. " is out of bounds for axis " + std::to_string(axis) +
  596. " with size " + std::to_string(*shape));
  597. }
  598. check_dimensions_impl(axis + 1, shape + 1, index...);
  599. }
  600. /// Create array from any object -- always returns a new reference
  601. static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) {
  602. if (ptr == nullptr)
  603. return nullptr;
  604. return detail::npy_api::get().PyArray_FromAny_(
  605. ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
  606. }
  607. };
  608. template <typename T, int ExtraFlags = array::forcecast> class array_t : public array {
  609. public:
  610. using value_type = T;
  611. array_t() : array(0, static_cast<const T *>(nullptr)) {}
  612. array_t(handle h, borrowed_t) : array(h, borrowed) { }
  613. array_t(handle h, stolen_t) : array(h, stolen) { }
  614. PYBIND11_DEPRECATED("Use array_t<T>::ensure() instead")
  615. array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen) {
  616. if (!m_ptr) PyErr_Clear();
  617. if (!is_borrowed) Py_XDECREF(h.ptr());
  618. }
  619. array_t(const object &o) : array(raw_array_t(o.ptr()), stolen) {
  620. if (!m_ptr) throw error_already_set();
  621. }
  622. explicit array_t(const buffer_info& info) : array(info) { }
  623. array_t(const std::vector<size_t> &shape,
  624. const std::vector<size_t> &strides, const T *ptr = nullptr,
  625. handle base = handle())
  626. : array(shape, strides, ptr, base) { }
  627. explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
  628. handle base = handle())
  629. : array(shape, ptr, base) { }
  630. explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
  631. : array(count, ptr, base) { }
  632. constexpr size_t itemsize() const {
  633. return sizeof(T);
  634. }
  635. template<typename... Ix> size_t index_at(Ix... index) const {
  636. return offset_at(index...) / itemsize();
  637. }
  638. template<typename... Ix> const T* data(Ix... index) const {
  639. return static_cast<const T*>(array::data(index...));
  640. }
  641. template<typename... Ix> T* mutable_data(Ix... index) {
  642. return static_cast<T*>(array::mutable_data(index...));
  643. }
  644. // Reference to element at a given index
  645. template<typename... Ix> const T& at(Ix... index) const {
  646. if (sizeof...(index) != ndim())
  647. fail_dim_check(sizeof...(index), "index dimension mismatch");
  648. return *(static_cast<const T*>(array::data()) + byte_offset(size_t(index)...) / itemsize());
  649. }
  650. // Mutable reference to element at a given index
  651. template<typename... Ix> T& mutable_at(Ix... index) {
  652. if (sizeof...(index) != ndim())
  653. fail_dim_check(sizeof...(index), "index dimension mismatch");
  654. return *(static_cast<T*>(array::mutable_data()) + byte_offset(size_t(index)...) / itemsize());
  655. }
  656. /** Returns a proxy object that provides access to the array's data without bounds or
  657. * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with
  658. * care: the array must not be destroyed or reshaped for the duration of the returned object,
  659. * and the caller must take care not to access invalid dimensions or dimension indices.
  660. */
  661. template <ssize_t Dims = -1> detail::unchecked_mutable_reference<T, Dims> mutable_unchecked() {
  662. return array::mutable_unchecked<T, Dims>();
  663. }
  664. /** Returns a proxy object that provides const access to the array's data without bounds or
  665. * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying
  666. * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped
  667. * for the duration of the returned object, and the caller must take care not to access invalid
  668. * dimensions or dimension indices.
  669. */
  670. template <ssize_t Dims = -1> detail::unchecked_reference<T, Dims> unchecked() const {
  671. return array::unchecked<T, Dims>();
  672. }
  673. /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert
  674. /// it). In case of an error, nullptr is returned and the Python error is cleared.
  675. static array_t ensure(handle h) {
  676. auto result = reinterpret_steal<array_t>(raw_array_t(h.ptr()));
  677. if (!result)
  678. PyErr_Clear();
  679. return result;
  680. }
  681. static bool check_(handle h) {
  682. const auto &api = detail::npy_api::get();
  683. return api.PyArray_Check_(h.ptr())
  684. && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of<T>().ptr());
  685. }
  686. protected:
  687. /// Create array from any object -- always returns a new reference
  688. static PyObject *raw_array_t(PyObject *ptr) {
  689. if (ptr == nullptr)
  690. return nullptr;
  691. return detail::npy_api::get().PyArray_FromAny_(
  692. ptr, dtype::of<T>().release().ptr(), 0, 0,
  693. detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr);
  694. }
  695. };
  696. template <typename T>
  697. struct format_descriptor<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
  698. static std::string format() {
  699. return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::format();
  700. }
  701. };
  702. template <size_t N> struct format_descriptor<char[N]> {
  703. static std::string format() { return std::to_string(N) + "s"; }
  704. };
  705. template <size_t N> struct format_descriptor<std::array<char, N>> {
  706. static std::string format() { return std::to_string(N) + "s"; }
  707. };
  708. template <typename T>
  709. struct format_descriptor<T, detail::enable_if_t<std::is_enum<T>::value>> {
  710. static std::string format() {
  711. return format_descriptor<
  712. typename std::remove_cv<typename std::underlying_type<T>::type>::type>::format();
  713. }
  714. };
  715. NAMESPACE_BEGIN(detail)
  716. template <typename T, int ExtraFlags>
  717. struct pyobject_caster<array_t<T, ExtraFlags>> {
  718. using type = array_t<T, ExtraFlags>;
  719. bool load(handle src, bool convert) {
  720. if (!convert && !type::check_(src))
  721. return false;
  722. value = type::ensure(src);
  723. return static_cast<bool>(value);
  724. }
  725. static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) {
  726. return src.inc_ref();
  727. }
  728. PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
  729. };
  730. template <typename T>
  731. struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::value>> {
  732. static bool compare(const buffer_info& b) {
  733. return npy_api::get().PyArray_EquivTypes_(dtype::of<T>().ptr(), dtype(b).ptr());
  734. }
  735. };
  736. template <typename T> struct npy_format_descriptor<T, enable_if_t<satisfies_any_of<T, std::is_arithmetic, is_complex>::value>> {
  737. private:
  738. // NB: the order here must match the one in common.h
  739. constexpr static const int values[15] = {
  740. npy_api::NPY_BOOL_,
  741. npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_SHORT_, npy_api::NPY_USHORT_,
  742. npy_api::NPY_INT_, npy_api::NPY_UINT_, npy_api::NPY_LONGLONG_, npy_api::NPY_ULONGLONG_,
  743. npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_,
  744. npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_
  745. };
  746. public:
  747. static constexpr int value = values[detail::is_fmt_numeric<T>::index];
  748. static pybind11::dtype dtype() {
  749. if (auto ptr = npy_api::get().PyArray_DescrFromType_(value))
  750. return reinterpret_borrow<pybind11::dtype>(ptr);
  751. pybind11_fail("Unsupported buffer format!");
  752. }
  753. template <typename T2 = T, enable_if_t<std::is_integral<T2>::value, int> = 0>
  754. static PYBIND11_DESCR name() {
  755. return _<std::is_same<T, bool>::value>(_("bool"),
  756. _<std::is_signed<T>::value>("int", "uint") + _<sizeof(T)*8>());
  757. }
  758. template <typename T2 = T, enable_if_t<std::is_floating_point<T2>::value, int> = 0>
  759. static PYBIND11_DESCR name() {
  760. return _<std::is_same<T, float>::value || std::is_same<T, double>::value>(
  761. _("float") + _<sizeof(T)*8>(), _("longdouble"));
  762. }
  763. template <typename T2 = T, enable_if_t<is_complex<T2>::value, int> = 0>
  764. static PYBIND11_DESCR name() {
  765. return _<std::is_same<typename T2::value_type, float>::value || std::is_same<typename T2::value_type, double>::value>(
  766. _("complex") + _<sizeof(typename T2::value_type)*16>(), _("longcomplex"));
  767. }
  768. };
  769. #define PYBIND11_DECL_CHAR_FMT \
  770. static PYBIND11_DESCR name() { return _("S") + _<N>(); } \
  771. static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); }
  772. template <size_t N> struct npy_format_descriptor<char[N]> { PYBIND11_DECL_CHAR_FMT };
  773. template <size_t N> struct npy_format_descriptor<std::array<char, N>> { PYBIND11_DECL_CHAR_FMT };
  774. #undef PYBIND11_DECL_CHAR_FMT
  775. template<typename T> struct npy_format_descriptor<T, enable_if_t<std::is_enum<T>::value>> {
  776. private:
  777. using base_descr = npy_format_descriptor<typename std::underlying_type<T>::type>;
  778. public:
  779. static PYBIND11_DESCR name() { return base_descr::name(); }
  780. static pybind11::dtype dtype() { return base_descr::dtype(); }
  781. };
// Describes one member of a C++ struct that is registered as a NumPy structured dtype.
// Instances are built by aggregate initialization in PYBIND11_FIELD_DESCRIPTOR_EX, so
// the member order below must match that macro and must not be changed.
struct field_descriptor {
    // Field name used in the dtype and in the generated buffer format string.
    const char *name;
    // Byte offset of the member within the struct (offsetof).
    size_t offset;
    // sizeof the member's type.
    size_t size;
    // alignof the member's type; a field whose offset is not a multiple of this is
    // marked '^' (unaligned) in the buffer format string.
    size_t alignment;
    // Buffer-protocol format string of the member's type.
    std::string format;
    // NumPy dtype of the member's type; a null dtype makes registration fail.
    dtype descr;
};
  790. inline PYBIND11_NOINLINE void register_structured_dtype(
  791. const std::initializer_list<field_descriptor>& fields,
  792. const std::type_info& tinfo, size_t itemsize,
  793. bool (*direct_converter)(PyObject *, void *&)) {
  794. auto& numpy_internals = get_numpy_internals();
  795. if (numpy_internals.get_type_info(tinfo, false))
  796. pybind11_fail("NumPy: dtype is already registered");
  797. list names, formats, offsets;
  798. for (auto field : fields) {
  799. if (!field.descr)
  800. pybind11_fail(std::string("NumPy: unsupported field dtype: `") +
  801. field.name + "` @ " + tinfo.name());
  802. names.append(PYBIND11_STR_TYPE(field.name));
  803. formats.append(field.descr);
  804. offsets.append(pybind11::int_(field.offset));
  805. }
  806. auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr();
  807. // There is an existing bug in NumPy (as of v1.11): trailing bytes are
  808. // not encoded explicitly into the format string. This will supposedly
  809. // get fixed in v1.12; for further details, see these:
  810. // - https://github.com/numpy/numpy/issues/7797
  811. // - https://github.com/numpy/numpy/pull/7798
  812. // Because of this, we won't use numpy's logic to generate buffer format
  813. // strings and will just do it ourselves.
  814. std::vector<field_descriptor> ordered_fields(fields);
  815. std::sort(ordered_fields.begin(), ordered_fields.end(),
  816. [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
  817. size_t offset = 0;
  818. std::ostringstream oss;
  819. oss << "T{";
  820. for (auto& field : ordered_fields) {
  821. if (field.offset > offset)
  822. oss << (field.offset - offset) << 'x';
  823. // mark unaligned fields with '^' (unaligned native type)
  824. if (field.offset % field.alignment)
  825. oss << '^';
  826. oss << field.format << ':' << field.name << ':';
  827. offset = field.offset + field.size;
  828. }
  829. if (itemsize > offset)
  830. oss << (itemsize - offset) << 'x';
  831. oss << '}';
  832. auto format_str = oss.str();
  833. // Sanity check: verify that NumPy properly parses our buffer format string
  834. auto& api = npy_api::get();
  835. auto arr = array(buffer_info(nullptr, itemsize, format_str, 1));
  836. if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr()))
  837. pybind11_fail("NumPy: invalid buffer descriptor!");
  838. auto tindex = std::type_index(tinfo);
  839. numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str };
  840. get_internals().direct_conversions[tindex].push_back(direct_converter);
  841. }
  842. template <typename T, typename SFINAE> struct npy_format_descriptor {
  843. static_assert(is_pod_struct<T>::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype");
  844. static PYBIND11_DESCR name() { return make_caster<T>::name(); }
  845. static pybind11::dtype dtype() {
  846. return reinterpret_borrow<pybind11::dtype>(dtype_ptr());
  847. }
  848. static std::string format() {
  849. static auto format_str = get_numpy_internals().get_type_info<T>(true)->format_str;
  850. return format_str;
  851. }
  852. static void register_dtype(const std::initializer_list<field_descriptor>& fields) {
  853. register_structured_dtype(fields, typeid(typename std::remove_cv<T>::type),
  854. sizeof(T), &direct_converter);
  855. }
  856. private:
  857. static PyObject* dtype_ptr() {
  858. static PyObject* ptr = get_numpy_internals().get_type_info<T>(true)->dtype_ptr;
  859. return ptr;
  860. }
  861. static bool direct_converter(PyObject *obj, void*& value) {
  862. auto& api = npy_api::get();
  863. if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_))
  864. return false;
  865. if (auto descr = reinterpret_steal<object>(api.PyArray_DescrFromScalar_(obj))) {
  866. if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) {
  867. value = ((PyVoidScalarObject_Proxy *) obj)->obval;
  868. return true;
  869. }
  870. }
  871. return false;
  872. }
  873. };
// Builds a ::pybind11::detail::field_descriptor for member `Field` of struct `T`, exposed
// under the name `Name`: offsetof, sizeof and alignof of the member, plus the member
// type's buffer format string and NumPy dtype (in field_descriptor's member order).
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
    ::pybind11::detail::field_descriptor { \
        Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
        alignof(decltype(std::declval<T>().Field)), \
        ::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
        ::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
    }
// Extract the field descriptor (name, offset, size, alignment, format string, dtype) for a
// struct field, using the stringized member name as the field name
#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field)
// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro
// (C) William Swanson, Paul Fultz
//
// PYBIND11_EVAL nests PYBIND11_EVAL0 3^5 = 243 levels deep. This forces the preprocessor to
// rescan its argument repeatedly, which drives the alternating LIST0/LIST1 expansion below
// one step per rescan (a macro may not expand itself directly).
#define PYBIND11_EVAL0(...) __VA_ARGS__
#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__)))
#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__)))
#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__)))
#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__)))
#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__)))
// Terminator: swallows whatever argument list follows it.
#define PYBIND11_MAP_END(...)
// Empty token placed after `next` so that it is not immediately followed by its
// argument list during the current scan (deferring its expansion to the next rescan).
#define PYBIND11_MAP_OUT
#define PYBIND11_MAP_COMMA ,
// End detection: `PYBIND11_MAP_GET_END test` only expands when `test` is the `()`
// sentinel, producing an extra leading argument `0, PYBIND11_MAP_END`. This shifts
// PYBIND11_MAP_NEXT0's `next` parameter to PYBIND11_MAP_END; for any other `test` the
// caller's `next` is selected.
#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END
#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT
#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0)
#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Like PYBIND11_MAP_NEXT, but emits a comma before `next` when not at the end.
#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP_LIST_NEXT(test, next) \
    PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Two mutually-referencing step macros: each emits f(t, x), then, unless `peek` is the
// `()` sentinel, a comma and the *other* step macro applied to the remaining arguments.
#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \
    f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ...
// (the trailing `(), 0` is the end sentinel consumed by the machinery above)
#define PYBIND11_MAP_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0))

// PYBIND11_NUMPY_DTYPE(Type, field1, field2, ...) registers `Type` as a NumPy structured
// dtype whose fields are the listed members, each named after the member itself.
#define PYBIND11_NUMPY_DTYPE(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)})
// Pairwise variant of the PYBIND11_MAP_LIST machinery: consumes arguments two at a time.
// The comma-emitting "next" helper mirrors PYBIND11_MAP_LIST_NEXT1, including the same
// MSVC workaround.
#ifdef _MSC_VER
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0))
#else
#define PYBIND11_MAP2_LIST_NEXT1(test, next) \
    PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)
#endif
#define PYBIND11_MAP2_LIST_NEXT(test, next) \
    PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next)
// Alternating step macros: each emits f(t, x1, x2) and, unless `peek` is the `()`
// sentinel, a comma followed by the other step macro on the remaining arguments.
#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__)
#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \
    f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__)
// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ...
#define PYBIND11_MAP2_LIST(f, t, ...) \
    PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0))

// PYBIND11_NUMPY_DTYPE_EX(Type, member1, "name1", member2, "name2", ...) registers `Type`
// as a NumPy structured dtype, exposing each member under an explicitly given name.
#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \
    ::pybind11::detail::npy_format_descriptor<Type>::register_dtype \
        ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)})
  936. template <class T>
  937. using array_iterator = typename std::add_pointer<T>::type;
  938. template <class T>
  939. array_iterator<T> array_begin(const buffer_info& buffer) {
  940. return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr));
  941. }
  942. template <class T>
  943. array_iterator<T> array_end(const buffer_info& buffer) {
  944. return array_iterator<T>(reinterpret_cast<T*>(buffer.ptr) + buffer.size);
  945. }
  946. class common_iterator {
  947. public:
  948. using container_type = std::vector<size_t>;
  949. using value_type = container_type::value_type;
  950. using size_type = container_type::size_type;
  951. common_iterator() : p_ptr(0), m_strides() {}
  952. common_iterator(void* ptr, const container_type& strides, const std::vector<size_t>& shape)
  953. : p_ptr(reinterpret_cast<char*>(ptr)), m_strides(strides.size()) {
  954. m_strides.back() = static_cast<value_type>(strides.back());
  955. for (size_type i = m_strides.size() - 1; i != 0; --i) {
  956. size_type j = i - 1;
  957. value_type s = static_cast<value_type>(shape[i]);
  958. m_strides[j] = strides[j] + m_strides[i] - strides[i] * s;
  959. }
  960. }
  961. void increment(size_type dim) {
  962. p_ptr += m_strides[dim];
  963. }
  964. void* data() const {
  965. return p_ptr;
  966. }
  967. private:
  968. char* p_ptr;
  969. container_type m_strides;
  970. };
  971. template <size_t N> class multi_array_iterator {
  972. public:
  973. using container_type = std::vector<size_t>;
  974. multi_array_iterator(const std::array<buffer_info, N> &buffers,
  975. const std::vector<size_t> &shape)
  976. : m_shape(shape.size()), m_index(shape.size(), 0),
  977. m_common_iterator() {
  978. // Manual copy to avoid conversion warning if using std::copy
  979. for (size_t i = 0; i < shape.size(); ++i)
  980. m_shape[i] = static_cast<container_type::value_type>(shape[i]);
  981. container_type strides(shape.size());
  982. for (size_t i = 0; i < N; ++i)
  983. init_common_iterator(buffers[i], shape, m_common_iterator[i], strides);
  984. }
  985. multi_array_iterator& operator++() {
  986. for (size_t j = m_index.size(); j != 0; --j) {
  987. size_t i = j - 1;
  988. if (++m_index[i] != m_shape[i]) {
  989. increment_common_iterator(i);
  990. break;
  991. } else {
  992. m_index[i] = 0;
  993. }
  994. }
  995. return *this;
  996. }
  997. template <size_t K, class T> const T& data() const {
  998. return *reinterpret_cast<T*>(m_common_iterator[K].data());
  999. }
  1000. private:
  1001. using common_iter = common_iterator;
  1002. void init_common_iterator(const buffer_info &buffer,
  1003. const std::vector<size_t> &shape,
  1004. common_iter &iterator, container_type &strides) {
  1005. auto buffer_shape_iter = buffer.shape.rbegin();
  1006. auto buffer_strides_iter = buffer.strides.rbegin();
  1007. auto shape_iter = shape.rbegin();
  1008. auto strides_iter = strides.rbegin();
  1009. while (buffer_shape_iter != buffer.shape.rend()) {
  1010. if (*shape_iter == *buffer_shape_iter)
  1011. *strides_iter = static_cast<size_t>(*buffer_strides_iter);
  1012. else
  1013. *strides_iter = 0;
  1014. ++buffer_shape_iter;
  1015. ++buffer_strides_iter;
  1016. ++shape_iter;
  1017. ++strides_iter;
  1018. }
  1019. std::fill(strides_iter, strides.rend(), 0);
  1020. iterator = common_iter(buffer.ptr, strides, shape);
  1021. }
  1022. void increment_common_iterator(size_t dim) {
  1023. for (auto &iter : m_common_iterator)
  1024. iter.increment(dim);
  1025. }
  1026. container_type m_shape;
  1027. container_type m_index;
  1028. std::array<common_iter, N> m_common_iterator;
  1029. };
  1030. enum class broadcast_trivial { non_trivial, c_trivial, f_trivial };
  1031. // Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial
  1032. // enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a
  1033. // singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage
  1034. // buffer; returns `non_trivial` otherwise.
  1035. template <size_t N>
  1036. broadcast_trivial broadcast(const std::array<buffer_info, N> &buffers, size_t &ndim, std::vector<size_t> &shape) {
  1037. ndim = std::accumulate(buffers.begin(), buffers.end(), size_t(0), [](size_t res, const buffer_info& buf) {
  1038. return std::max(res, buf.ndim);
  1039. });
  1040. shape.clear();
  1041. shape.resize(ndim, 1);
  1042. // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or
  1043. // the full size).
  1044. for (size_t i = 0; i < N; ++i) {
  1045. auto res_iter = shape.rbegin();
  1046. auto end = buffers[i].shape.rend();
  1047. for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) {
  1048. const auto &dim_size_in = *shape_iter;
  1049. auto &dim_size_out = *res_iter;
  1050. // Each input dimension can either be 1 or `n`, but `n` values must match across buffers
  1051. if (dim_size_out == 1)
  1052. dim_size_out = dim_size_in;
  1053. else if (dim_size_in != 1 && dim_size_in != dim_size_out)
  1054. pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!");
  1055. }
  1056. }
  1057. bool trivial_broadcast_c = true;
  1058. bool trivial_broadcast_f = true;
  1059. for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) {
  1060. if (buffers[i].size == 1)
  1061. continue;
  1062. // Require the same number of dimensions:
  1063. if (buffers[i].ndim != ndim)
  1064. return broadcast_trivial::non_trivial;
  1065. // Require all dimensions be full-size:
  1066. if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin()))
  1067. return broadcast_trivial::non_trivial;
  1068. // Check for C contiguity (but only if previous inputs were also C contiguous)
  1069. if (trivial_broadcast_c) {
  1070. size_t expect_stride = buffers[i].itemsize;
  1071. auto end = buffers[i].shape.crend();
  1072. for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin();
  1073. trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) {
  1074. if (expect_stride == *stride_iter)
  1075. expect_stride *= *shape_iter;
  1076. else
  1077. trivial_broadcast_c = false;
  1078. }
  1079. }
  1080. // Check for Fortran contiguity (if previous inputs were also F contiguous)
  1081. if (trivial_broadcast_f) {
  1082. size_t expect_stride = buffers[i].itemsize;
  1083. auto end = buffers[i].shape.cend();
  1084. for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin();
  1085. trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) {
  1086. if (expect_stride == *stride_iter)
  1087. expect_stride *= *shape_iter;
  1088. else
  1089. trivial_broadcast_f = false;
  1090. }
  1091. }
  1092. }
  1093. return
  1094. trivial_broadcast_c ? broadcast_trivial::c_trivial :
  1095. trivial_broadcast_f ? broadcast_trivial::f_trivial :
  1096. broadcast_trivial::non_trivial;
  1097. }
// Callable returned by vectorize(): applies `f` element-wise over one
// array_t<Args, array::forcecast> per parameter of `f`, broadcasting the inputs
// against each other, and produces an array_t<Return> (or a scalar, see run()).
template <typename Func, typename Return, typename... Args>
struct vectorize_helper {
    // The wrapped callable, stored by value.
    typename std::remove_reference<Func>::type f;
    // Number of vectorized parameters.
    static constexpr size_t N = sizeof...(Args);

    template <typename T>
    explicit vectorize_helper(T&&f) : f(std::forward<T>(f)) { }

    // Entry point: one array argument per parameter of `f`.
    object operator()(array_t<Args, array::forcecast>... args) {
        return run(args..., make_index_sequence<N>());
    }

    // Computes the broadcast shape. If the broadcast result has exactly one element,
    // returns `f` applied to the first element of each input, cast to a Python scalar.
    // Otherwise allocates the output array and fills it.
    // NOTE(review): the single-element path also triggers when ndim > 0 (e.g. all
    // inputs of shape (1, 1)), returning a scalar instead of an array; confirm intended.
    template <size_t ... Index> object run(array_t<Args, array::forcecast>&... args, index_sequence<Index...> index) {
        /* Request buffers from all parameters */
        std::array<buffer_info, N> buffers {{ args.request()... }};

        /* Determine dimensions parameters of output array */
        size_t ndim = 0;
        std::vector<size_t> shape(0);
        auto trivial = broadcast(buffers, ndim, shape);

        // Packed output strides: Fortran order when the inputs are f_trivial, C order
        // otherwise. `size` accumulates the total number of output elements.
        size_t size = 1;
        std::vector<size_t> strides(ndim);
        if (ndim > 0) {
            if (trivial == broadcast_trivial::f_trivial) {
                strides[0] = sizeof(Return);
                for (size_t i = 1; i < ndim; ++i) {
                    strides[i] = strides[i - 1] * shape[i - 1];
                    size *= shape[i - 1];
                }
                size *= shape[ndim - 1];
            }
            else {
                strides[ndim-1] = sizeof(Return);
                for (size_t i = ndim - 1; i > 0; --i) {
                    strides[i - 1] = strides[i] * shape[i];
                    size *= shape[i];
                }
                size *= shape[0];
            }
        }

        if (size == 1)
            return cast(f(*reinterpret_cast<Args *>(buffers[Index].ptr)...));

        array_t<Return> result(shape, strides);
        auto buf = result.request();
        auto output = (Return *) buf.ptr;

        /* Call the function */
        // Trivial broadcasts walk every input and the output with one flat index;
        // single-element inputs always read element 0. Otherwise fall back to
        // apply_broadcast's per-dimension iteration.
        if (trivial == broadcast_trivial::non_trivial) {
            apply_broadcast<Index...>(buffers, buf, index);
        } else {
            for (size_t i = 0; i < size; ++i)
                output[i] = f((reinterpret_cast<Args *>(buffers[Index].ptr)[buffers[Index].size == 1 ? 0 : i])...);
        }

        return result;
    }

    // Non-trivial broadcast: iterates the output linearly in memory order (C order, as
    // run() computed C-ordered strides for this case) while advancing all inputs in
    // lock-step through a multi_array_iterator over the output shape.
    template <size_t... Index>
    void apply_broadcast(const std::array<buffer_info, N> &buffers,
                         buffer_info &output, index_sequence<Index...>) {
        using input_iterator = multi_array_iterator<N>;
        using output_iterator = array_iterator<Return>;

        input_iterator input_iter(buffers, output.shape);
        output_iterator output_end = array_end<Return>(output);

        for (output_iterator iter = array_begin<Return>(output);
             iter != output_end; ++iter, ++input_iter) {
            *iter = f((input_iter.template data<Index, Args>())...);
        }
    }
};
  1161. template <typename T, int Flags> struct handle_type_name<array_t<T, Flags>> {
  1162. static PYBIND11_DESCR name() {
  1163. return _("numpy.ndarray[") + npy_format_descriptor<T>::name() + _("]");
  1164. }
  1165. };
  1166. NAMESPACE_END(detail)
  1167. template <typename Func, typename Return, typename... Args>
  1168. detail::vectorize_helper<Func, Return, Args...>
  1169. vectorize(const Func &f, Return (*) (Args ...)) {
  1170. return detail::vectorize_helper<Func, Return, Args...>(f);
  1171. }
  1172. template <typename Return, typename... Args>
  1173. detail::vectorize_helper<Return (*) (Args ...), Return, Args...>
  1174. vectorize(Return (*f) (Args ...)) {
  1175. return vectorize<Return (*) (Args ...), Return, Args...>(f, f);
  1176. }
  1177. template <typename Func, typename FuncType = typename detail::remove_class<decltype(&std::remove_reference<Func>::type::operator())>::type>
  1178. auto vectorize(Func &&f) -> decltype(
  1179. vectorize(std::forward<Func>(f), (FuncType *) nullptr)) {
  1180. return vectorize(std::forward<Func>(f), (FuncType *) nullptr);
  1181. }
  1182. NAMESPACE_END(pybind11)
  1183. #if defined(_MSC_VER)
  1184. #pragma warning(pop)
  1185. #endif