You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

587 lines
20 KiB

8 years ago
  1. /*
  2. pybind11/std_bind.h: Binding generators for STL data types
  3. Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob
  4. All rights reserved. Use of this source code is governed by a
  5. BSD-style license that can be found in the LICENSE file.
  6. */
  7. #pragma once
  8. #include "common.h"
  9. #include "operators.h"
  10. #include <algorithm>
  11. #include <sstream>
  12. NAMESPACE_BEGIN(pybind11)
  13. NAMESPACE_BEGIN(detail)
  /* SFINAE helper class used by 'is_comparable' */
  template <typename T> struct container_traits {
      // Each probe has a preferred overload that is only viable when the expression in its parameter
      // type is well-formed, plus a variadic catch-all that always matches.

      // operator== on two const lvalues must exist and yield a non-reference (pointer-formable) type.
      template <typename U> static std::true_type test_comparable(decltype(std::declval<const U &>() == std::declval<const U &>()) *);
      template <typename U> static std::false_type test_comparable(...);

      // A nested value_type exists (std::vector, std::map, std::string, ...).
      template <typename U> static std::true_type test_value(typename U::value_type *);
      template <typename U> static std::false_type test_value(...);

      // Nested first_type and second_type exist (std::pair).
      template <typename U> static std::true_type test_pair(typename U::first_type *, typename U::second_type *);
      template <typename U> static std::false_type test_pair(...);

      // The chosen overload's return type is std::true_type or std::false_type; read its ::value.
      static constexpr const bool is_comparable = decltype(test_comparable<T>(nullptr))::value;
      static constexpr const bool is_pair = decltype(test_pair<T>(nullptr, nullptr))::value;
      static constexpr const bool is_vector = decltype(test_value<T>(nullptr))::value;
      // Neither pair-like nor value_type-bearing: a plain element type.
      static constexpr const bool is_element = !(is_pair || is_vector);
  };
  27. /* Default: is_comparable -> std::false_type */
  28. template <typename T, typename SFINAE = void>
  29. struct is_comparable : std::false_type { };
  30. /* For non-map data structures, check whether operator== can be instantiated */
  31. template <typename T>
  32. struct is_comparable<
  33. T, enable_if_t<container_traits<T>::is_element &&
  34. container_traits<T>::is_comparable>>
  35. : std::true_type { };
  36. /* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
  37. template <typename T>
  38. struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
  39. static constexpr const bool value =
  40. is_comparable<typename T::value_type>::value;
  41. };
  42. /* For pairs, recursively check the two data types */
  43. template <typename T>
  44. struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
  45. static constexpr const bool value =
  46. is_comparable<typename T::first_type>::value &&
  47. is_comparable<typename T::second_type>::value;
  48. };
  /* Fallback functions: no-op overloads that accept and ignore any arguments. They are chosen only
     when the SFINAE-constrained overload of the same name (below) is removed from overload resolution;
     otherwise that overload wins, because it binds `Class_ &` rather than `const Class_ &`. */
  template <typename Vector, typename Class_, typename... Ignored> void vector_if_copy_constructible(const Ignored &...) { }
  template <typename Vector, typename Class_, typename... Ignored> void vector_if_equal_operator(const Ignored &...) { }
  template <typename Vector, typename Class_, typename... Ignored> void vector_if_insertion_operator(const Ignored &...) { }
  template <typename Vector, typename Class_, typename... Ignored> void vector_modifiers(const Ignored &...) { }
  54. template<typename Vector, typename Class_>
  55. void vector_if_copy_constructible(enable_if_t<
  56. std::is_copy_constructible<Vector>::value &&
  57. std::is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
  58. cl.def(init<const Vector &>(), "Copy constructor");
  59. }
  60. template<typename Vector, typename Class_>
  61. void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
  62. using T = typename Vector::value_type;
  63. cl.def(self == self);
  64. cl.def(self != self);
  65. cl.def("count",
  66. [](const Vector &v, const T &x) {
  67. return std::count(v.begin(), v.end(), x);
  68. },
  69. arg("x"),
  70. "Return the number of times ``x`` appears in the list"
  71. );
  72. cl.def("remove", [](Vector &v, const T &x) {
  73. auto p = std::find(v.begin(), v.end(), x);
  74. if (p != v.end())
  75. v.erase(p);
  76. else
  77. throw value_error();
  78. },
  79. arg("x"),
  80. "Remove the first item from the list whose value is x. "
  81. "It is an error if there is no such item."
  82. );
  83. cl.def("__contains__",
  84. [](const Vector &v, const T &x) {
  85. return std::find(v.begin(), v.end(), x) != v.end();
  86. },
  87. arg("x"),
  88. "Return true the container contains ``x``"
  89. );
  90. }
  91. // Vector modifiers -- requires a copyable vector_type:
  92. // (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
  93. // silly to allow deletion but not insertion, so include them here too.)
  94. template <typename Vector, typename Class_>
  95. void vector_modifiers(enable_if_t<std::is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
  96. using T = typename Vector::value_type;
  97. using SizeType = typename Vector::size_type;
  98. using DiffType = typename Vector::difference_type;
  99. cl.def("append",
  100. [](Vector &v, const T &value) { v.push_back(value); },
  101. arg("x"),
  102. "Add an item to the end of the list");
  103. cl.def("__init__", [](Vector &v, iterable it) {
  104. new (&v) Vector();
  105. try {
  106. v.reserve(len(it));
  107. for (handle h : it)
  108. v.push_back(h.cast<T>());
  109. } catch (...) {
  110. v.~Vector();
  111. throw;
  112. }
  113. });
  114. cl.def("extend",
  115. [](Vector &v, const Vector &src) {
  116. v.reserve(v.size() + src.size());
  117. v.insert(v.end(), src.begin(), src.end());
  118. },
  119. arg("L"),
  120. "Extend the list by appending all the items in the given list"
  121. );
  122. cl.def("insert",
  123. [](Vector &v, SizeType i, const T &x) {
  124. v.insert(v.begin() + (DiffType) i, x);
  125. },
  126. arg("i") , arg("x"),
  127. "Insert an item at a given position."
  128. );
  129. cl.def("pop",
  130. [](Vector &v) {
  131. if (v.empty())
  132. throw index_error();
  133. T t = v.back();
  134. v.pop_back();
  135. return t;
  136. },
  137. "Remove and return the last item"
  138. );
  139. cl.def("pop",
  140. [](Vector &v, SizeType i) {
  141. if (i >= v.size())
  142. throw index_error();
  143. T t = v[i];
  144. v.erase(v.begin() + (DiffType) i);
  145. return t;
  146. },
  147. arg("i"),
  148. "Remove and return the item at index ``i``"
  149. );
  150. cl.def("__setitem__",
  151. [](Vector &v, SizeType i, const T &t) {
  152. if (i >= v.size())
  153. throw index_error();
  154. v[i] = t;
  155. }
  156. );
  157. /// Slicing protocol
  158. cl.def("__getitem__",
  159. [](const Vector &v, slice slice) -> Vector * {
  160. size_t start, stop, step, slicelength;
  161. if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
  162. throw error_already_set();
  163. Vector *seq = new Vector();
  164. seq->reserve((size_t) slicelength);
  165. for (size_t i=0; i<slicelength; ++i) {
  166. seq->push_back(v[start]);
  167. start += step;
  168. }
  169. return seq;
  170. },
  171. arg("s"),
  172. "Retrieve list elements using a slice object"
  173. );
  174. cl.def("__setitem__",
  175. [](Vector &v, slice slice, const Vector &value) {
  176. size_t start, stop, step, slicelength;
  177. if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
  178. throw error_already_set();
  179. if (slicelength != value.size())
  180. throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
  181. for (size_t i=0; i<slicelength; ++i) {
  182. v[start] = value[i];
  183. start += step;
  184. }
  185. },
  186. "Assign list elements using a slice object"
  187. );
  188. cl.def("__delitem__",
  189. [](Vector &v, SizeType i) {
  190. if (i >= v.size())
  191. throw index_error();
  192. v.erase(v.begin() + DiffType(i));
  193. },
  194. "Delete the list elements at index ``i``"
  195. );
  196. cl.def("__delitem__",
  197. [](Vector &v, slice slice) {
  198. size_t start, stop, step, slicelength;
  199. if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
  200. throw error_already_set();
  201. if (step == 1 && false) {
  202. v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
  203. } else {
  204. for (size_t i = 0; i < slicelength; ++i) {
  205. v.erase(v.begin() + DiffType(start));
  206. start += step - 1;
  207. }
  208. }
  209. },
  210. "Delete list elements using a slice object"
  211. );
  212. }
  213. // If the type has an operator[] that doesn't return a reference (most notably std::vector<bool>),
  214. // we have to access by copying; otherwise we return by reference.
  215. template <typename Vector> using vector_needs_copy = negation<
  216. std::is_same<decltype(std::declval<Vector>()[typename Vector::size_type()]), typename Vector::value_type &>>;
  217. // The usual case: access and iterate by reference
  218. template <typename Vector, typename Class_>
  219. void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
  220. using T = typename Vector::value_type;
  221. using SizeType = typename Vector::size_type;
  222. using ItType = typename Vector::iterator;
  223. cl.def("__getitem__",
  224. [](Vector &v, SizeType i) -> T & {
  225. if (i >= v.size())
  226. throw index_error();
  227. return v[i];
  228. },
  229. return_value_policy::reference_internal // ref + keepalive
  230. );
  231. cl.def("__iter__",
  232. [](Vector &v) {
  233. return make_iterator<
  234. return_value_policy::reference_internal, ItType, ItType, T&>(
  235. v.begin(), v.end());
  236. },
  237. keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
  238. );
  239. }
  240. // The case for special objects, like std::vector<bool>, that have to be returned-by-copy:
  241. template <typename Vector, typename Class_>
  242. void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
  243. using T = typename Vector::value_type;
  244. using SizeType = typename Vector::size_type;
  245. using ItType = typename Vector::iterator;
  246. cl.def("__getitem__",
  247. [](const Vector &v, SizeType i) -> T {
  248. if (i >= v.size())
  249. throw index_error();
  250. return v[i];
  251. }
  252. );
  253. cl.def("__iter__",
  254. [](Vector &v) {
  255. return make_iterator<
  256. return_value_policy::copy, ItType, ItType, T>(
  257. v.begin(), v.end());
  258. },
  259. keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
  260. );
  261. }
  // Bind __repr__ as "name[a, b, c]". The trailing return type removes this overload unless the
  // element type can be streamed into std::ostream.
  template <typename Vector, typename Class_> auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
      -> decltype(std::declval<std::ostream&>() << std::declval<typename Vector::value_type>(), void()) {
      using size_type = typename Vector::size_type;

      cl.def("__repr__",
          [name](Vector &v) {
              std::ostringstream out;
              out << name << '[';
              // Emit the separator before every element except the first.
              const char *sep = "";
              for (size_type idx = 0; idx < v.size(); ++idx) {
                  out << sep << v[idx];
                  sep = ", ";
              }
              out << ']';
              return out.str();
          },
          "Return the canonical string representation of this list."
      );
  }
  280. // Provide the buffer interface for vectors if we have data() and we have a format for it
  281. // GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
  282. template <typename Vector, typename = void>
  283. struct vector_has_data_and_format : std::false_type {};
  284. template <typename Vector>
  285. struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
  286. // Add the buffer interface to a vector
  287. template <typename Vector, typename Class_, typename... Args>
  288. enable_if_t<detail::any_of<std::is_same<Args, buffer_protocol>...>::value>
  289. vector_buffer(Class_& cl) {
  290. using T = typename Vector::value_type;
  291. static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
  292. // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
  293. format_descriptor<T>::format();
  294. cl.def_buffer([](Vector& v) -> buffer_info {
  295. return buffer_info(v.data(), sizeof(T), format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
  296. });
  297. cl.def("__init__", [](Vector& vec, buffer buf) {
  298. auto info = buf.request();
  299. if (info.ndim != 1 || info.strides[0] <= 0 || info.strides[0] % sizeof(T))
  300. throw type_error("Only valid 1D buffers can be copied to a vector");
  301. if (!detail::compare_buffer_info<T>::compare(info) || sizeof(T) != info.itemsize)
  302. throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
  303. new (&vec) Vector();
  304. vec.reserve(info.shape[0]);
  305. T *p = static_cast<T*>(info.ptr);
  306. auto step = info.strides[0] / sizeof(T);
  307. T *end = p + info.shape[0] * step;
  308. for (; p < end; p += step)
  309. vec.push_back(*p);
  310. });
  311. return;
  312. }
  313. template <typename Vector, typename Class_, typename... Args>
  314. enable_if_t<!detail::any_of<std::is_same<Args, buffer_protocol>...>::value> vector_buffer(Class_&) {}
  315. NAMESPACE_END(detail)
  316. //
  317. // std::vector
  318. //
  319. template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
  320. class_<Vector, holder_type> bind_vector(module &m, std::string const &name, Args&&... args) {
  321. using Class_ = class_<Vector, holder_type>;
  322. Class_ cl(m, name.c_str(), std::forward<Args>(args)...);
  323. // Declare the buffer interface if a buffer_protocol() is passed in
  324. detail::vector_buffer<Vector, Class_, Args...>(cl);
  325. cl.def(init<>());
  326. // Register copy constructor (if possible)
  327. detail::vector_if_copy_constructible<Vector, Class_>(cl);
  328. // Register comparison-related operators and functions (if possible)
  329. detail::vector_if_equal_operator<Vector, Class_>(cl);
  330. // Register stream insertion operator (if possible)
  331. detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
  332. // Modifiers require copyable vector value type
  333. detail::vector_modifiers<Vector, Class_>(cl);
  334. // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
  335. detail::vector_accessor<Vector, Class_>(cl);
  336. cl.def("__bool__",
  337. [](const Vector &v) -> bool {
  338. return !v.empty();
  339. },
  340. "Check whether the list is nonempty"
  341. );
  342. cl.def("__len__", &Vector::size);
  343. #if 0
  344. // C++ style functions deprecated, leaving it here as an example
  345. cl.def(init<size_type>());
  346. cl.def("resize",
  347. (void (Vector::*) (size_type count)) & Vector::resize,
  348. "changes the number of elements stored");
  349. cl.def("erase",
  350. [](Vector &v, SizeType i) {
  351. if (i >= v.size())
  352. throw index_error();
  353. v.erase(v.begin() + i);
  354. }, "erases element at index ``i``");
  355. cl.def("empty", &Vector::empty, "checks whether the container is empty");
  356. cl.def("size", &Vector::size, "returns the number of elements");
  357. cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
  358. cl.def("pop_back", &Vector::pop_back, "removes the last element");
  359. cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements");
  360. cl.def("reserve", &Vector::reserve, "reserves storage");
  361. cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage");
  362. cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
  363. cl.def("clear", &Vector::clear, "clears the contents");
  364. cl.def("swap", &Vector::swap, "swaps the contents");
  365. cl.def("front", [](Vector &v) {
  366. if (v.size()) return v.front();
  367. else throw index_error();
  368. }, "access the first element");
  369. cl.def("back", [](Vector &v) {
  370. if (v.size()) return v.back();
  371. else throw index_error();
  372. }, "access the last element ");
  373. #endif
  374. return cl;
  375. }
  376. //
  377. // std::map, std::unordered_map
  378. //
  379. NAMESPACE_BEGIN(detail)
  /* Fallback functions: no-op overloads that accept and ignore any arguments. They are picked only
     when the SFINAE-constrained overload of the same name is not viable. */
  template <typename Map, typename Class_, typename... Ignored> void map_if_insertion_operator(const Ignored &...) { }
  template <typename Map, typename Class_, typename... Ignored> void map_assignment(const Ignored &...) { }
  383. // Map assignment when copy-assignable: just copy the value
  384. template <typename Map, typename Class_>
  385. void map_assignment(enable_if_t<std::is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
  386. using KeyType = typename Map::key_type;
  387. using MappedType = typename Map::mapped_type;
  388. cl.def("__setitem__",
  389. [](Map &m, const KeyType &k, const MappedType &v) {
  390. auto it = m.find(k);
  391. if (it != m.end()) it->second = v;
  392. else m.emplace(k, v);
  393. }
  394. );
  395. }
  396. // Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
  397. template<typename Map, typename Class_>
  398. void map_assignment(enable_if_t<
  399. !std::is_copy_assignable<typename Map::mapped_type>::value &&
  400. std::is_copy_constructible<typename Map::mapped_type>::value,
  401. Class_> &cl) {
  402. using KeyType = typename Map::key_type;
  403. using MappedType = typename Map::mapped_type;
  404. cl.def("__setitem__",
  405. [](Map &m, const KeyType &k, const MappedType &v) {
  406. // We can't use m[k] = v; because value type might not be default constructable
  407. auto r = m.emplace(k, v);
  408. if (!r.second) {
  409. // value type is not copy assignable so the only way to insert it is to erase it first...
  410. m.erase(r.first);
  411. m.emplace(k, v);
  412. }
  413. }
  414. );
  415. }
  // Bind __repr__ as "name{k1: v1, k2: v2}". The trailing return type removes this overload unless
  // both key and mapped types can be streamed into std::ostream.
  template <typename Map, typename Class_> auto map_if_insertion_operator(Class_ &cl, std::string const &name)
      -> decltype(std::declval<std::ostream&>() << std::declval<typename Map::key_type>() << std::declval<typename Map::mapped_type>(), void()) {

      cl.def("__repr__",
             [name](Map &m) {
              std::ostringstream out;
              out << name << '{';
              // Separator is empty before the first entry and ", " before every later one.
              const char *sep = "";
              for (auto const &entry : m) {
                  out << sep << entry.first << ": " << entry.second;
                  sep = ", ";
              }
              out << '}';
              return out.str();
             },
             "Return the canonical string representation of this map."
      );
  }
  435. NAMESPACE_END(detail)
  436. template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
  437. class_<Map, holder_type> bind_map(module &m, const std::string &name, Args&&... args) {
  438. using KeyType = typename Map::key_type;
  439. using MappedType = typename Map::mapped_type;
  440. using Class_ = class_<Map, holder_type>;
  441. Class_ cl(m, name.c_str(), std::forward<Args>(args)...);
  442. cl.def(init<>());
  443. // Register stream insertion operator (if possible)
  444. detail::map_if_insertion_operator<Map, Class_>(cl, name);
  445. cl.def("__bool__",
  446. [](const Map &m) -> bool { return !m.empty(); },
  447. "Check whether the map is nonempty"
  448. );
  449. cl.def("__iter__",
  450. [](Map &m) { return make_key_iterator(m.begin(), m.end()); },
  451. keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
  452. );
  453. cl.def("items",
  454. [](Map &m) { return make_iterator(m.begin(), m.end()); },
  455. keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
  456. );
  457. cl.def("__getitem__",
  458. [](Map &m, const KeyType &k) -> MappedType & {
  459. auto it = m.find(k);
  460. if (it == m.end())
  461. throw key_error();
  462. return it->second;
  463. },
  464. return_value_policy::reference_internal // ref + keepalive
  465. );
  466. // Assignment provided only if the type is copyable
  467. detail::map_assignment<Map, Class_>(cl);
  468. cl.def("__delitem__",
  469. [](Map &m, const KeyType &k) {
  470. auto it = m.find(k);
  471. if (it == m.end())
  472. throw key_error();
  473. return m.erase(it);
  474. }
  475. );
  476. cl.def("__len__", &Map::size);
  477. return cl;
  478. }
  479. NAMESPACE_END(pybind11)