The source code and dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

184 lines
5.6 KiB

2 months ago
  1. /*
  2. tests/test_callbacks.cpp -- callbacks
  3. Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
  4. All rights reserved. Use of this source code is governed by a
  5. BSD-style license that can be found in the LICENSE file.
  6. */
  7. #include "pybind11_tests.h"
  8. #include "constructor_stats.h"
  9. #include <pybind11/functional.h>
  10. py::object test_callback1(py::object func) {
  11. return func();
  12. }
  13. py::tuple test_callback2(py::object func) {
  14. return func("Hello", 'x', true, 5);
  15. }
  // Calls `func` with the fixed argument 43 and formats the outcome as
  // "func(43) = <value>".
  std::string test_callback3(const std::function<int(int)> &func) {
      const int value = func(43);
      std::string message = "func(43) = ";
      message += std::to_string(value);
      return message;
  }
  // Returns a std::function wrapping a stateless lambda that adds one to its argument.
  std::function<int(int)> test_callback4() {
      auto increment = [](int i) { return i + 1; };
      return increment;
  }
  22. py::cpp_function test_callback5() {
  23. return py::cpp_function([](int i) { return i+1; },
  24. py::arg("number"));
  25. }
  // Plain free function returning `i + 1`; test_dummy_function compares
  // function pointers against it, so it must stay an ordinary `int(int)`.
  int dummy_function(int i) {
      const int successor = i + 1;
      return successor;
  }
  // Two-argument counterpart of dummy_function: returns the sum of its inputs.
  int dummy_function2(int i, int j) {
      const int sum = i + j;
      return sum;
  }
  // Returns `f` unchanged. When `expect_none` is set, `f` is required to be
  // an empty std::function; a non-empty one raises std::runtime_error.
  std::function<int(int)> roundtrip(std::function<int(int)> f, bool expect_none = false) {
      const bool holds_target = static_cast<bool>(f);
      if (!expect_none || !holds_target) {
          return f;
      }
      throw std::runtime_error("Expected None to be converted to empty std::function");
  }
  34. std::string test_dummy_function(const std::function<int(int)> &f) {
  35. using fn_type = int (*)(int);
  36. auto result = f.target<fn_type>();
  37. if (!result) {
  38. auto r = f(1);
  39. return "can't convert to function pointer: eval(1) = " + std::to_string(r);
  40. } else if (*result == dummy_function) {
  41. auto r = (*result)(1);
  42. return "matches dummy_function: eval(1) = " + std::to_string(r);
  43. } else {
  44. return "argument does NOT match dummy_function. This should never happen!";
  45. }
  46. }
  47. struct Payload {
  48. Payload() {
  49. print_default_created(this);
  50. }
  51. ~Payload() {
  52. print_destroyed(this);
  53. }
  54. Payload(const Payload &) {
  55. print_copy_created(this);
  56. }
  57. Payload(Payload &&) {
  58. print_move_created(this);
  59. }
  60. };
  // Abstract interface used as the argument type of a callback signature
  // (see func_accepting_func_accepting_base).
  class AbstractBase {
  public:
      // A polymorphic base needs a virtual destructor so that destroying a
      // derived object through an AbstractBase pointer is well defined.
      virtual ~AbstractBase() = default;

      // Pure virtual hook to be supplied by concrete subclasses.
      virtual unsigned int func() = 0;
  };
  65. void func_accepting_func_accepting_base(std::function<double(AbstractBase&)>) { }
  // Small value type whose `valid` flag records whether it has been moved from:
  // copying leaves the source untouched, moving transfers the flag and marks
  // the source as invalid.
  struct MovableObject {
      bool valid = true;

      MovableObject() = default;
      MovableObject(const MovableObject &) = default;
      MovableObject &operator=(const MovableObject &) = default;

      // Move operations cannot throw; declaring them noexcept lets standard
      // containers move instead of copy when they relocate elements.
      MovableObject(MovableObject &&o) noexcept : valid(o.valid) { o.valid = false; }

      MovableObject &operator=(MovableObject &&o) noexcept {
          // Guard against self-move, which would otherwise clear our own flag.
          if (this != &o) {
              valid = o.valid;
              o.valid = false;
          }
          return *this;
      }
  };
  78. test_initializer callbacks([](py::module &m) {
  79. m.def("test_callback1", &test_callback1);
  80. m.def("test_callback2", &test_callback2);
  81. m.def("test_callback3", &test_callback3);
  82. m.def("test_callback4", &test_callback4);
  83. m.def("test_callback5", &test_callback5);
  84. // Test keyword args and generalized unpacking
  85. m.def("test_tuple_unpacking", [](py::function f) {
  86. auto t1 = py::make_tuple(2, 3);
  87. auto t2 = py::make_tuple(5, 6);
  88. return f("positional", 1, *t1, 4, *t2);
  89. });
  90. m.def("test_dict_unpacking", [](py::function f) {
  91. auto d1 = py::dict("key"_a="value", "a"_a=1);
  92. auto d2 = py::dict();
  93. auto d3 = py::dict("b"_a=2);
  94. return f("positional", 1, **d1, **d2, **d3);
  95. });
  96. m.def("test_keyword_args", [](py::function f) {
  97. return f("x"_a=10, "y"_a=20);
  98. });
  99. m.def("test_unpacking_and_keywords1", [](py::function f) {
  100. auto args = py::make_tuple(2);
  101. auto kwargs = py::dict("d"_a=4);
  102. return f(1, *args, "c"_a=3, **kwargs);
  103. });
  104. m.def("test_unpacking_and_keywords2", [](py::function f) {
  105. auto kwargs1 = py::dict("a"_a=1);
  106. auto kwargs2 = py::dict("c"_a=3, "d"_a=4);
  107. return f("positional", *py::make_tuple(1), 2, *py::make_tuple(3, 4), 5,
  108. "key"_a="value", **kwargs1, "b"_a=2, **kwargs2, "e"_a=5);
  109. });
  110. m.def("test_unpacking_error1", [](py::function f) {
  111. auto kwargs = py::dict("x"_a=3);
  112. return f("x"_a=1, "y"_a=2, **kwargs); // duplicate ** after keyword
  113. });
  114. m.def("test_unpacking_error2", [](py::function f) {
  115. auto kwargs = py::dict("x"_a=3);
  116. return f(**kwargs, "x"_a=1); // duplicate keyword after **
  117. });
  118. m.def("test_arg_conversion_error1", [](py::function f) {
  119. f(234, UnregisteredType(), "kw"_a=567);
  120. });
  121. m.def("test_arg_conversion_error2", [](py::function f) {
  122. f(234, "expected_name"_a=UnregisteredType(), "kw"_a=567);
  123. });
  124. /* Test cleanup of lambda closure */
  125. m.def("test_cleanup", []() -> std::function<void(void)> {
  126. Payload p;
  127. return [p]() {
  128. /* p should be cleaned up when the returned function is garbage collected */
  129. (void) p;
  130. };
  131. });
  132. /* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */
  133. m.def("dummy_function", &dummy_function);
  134. m.def("dummy_function2", &dummy_function2);
  135. m.def("roundtrip", &roundtrip, py::arg("f"), py::arg("expect_none")=false);
  136. m.def("test_dummy_function", &test_dummy_function);
  137. // Export the payload constructor statistics for testing purposes:
  138. m.def("payload_cstats", &ConstructorStats::get<Payload>);
  139. m.def("func_accepting_func_accepting_base",
  140. func_accepting_func_accepting_base);
  141. py::class_<MovableObject>(m, "MovableObject");
  142. m.def("callback_with_movable", [](std::function<void(MovableObject &)> f) {
  143. auto x = MovableObject();
  144. f(x); // lvalue reference shouldn't move out object
  145. return x.valid; // must still return `true`
  146. });
  147. struct CppBoundMethodTest {};
  148. py::class_<CppBoundMethodTest>(m, "CppBoundMethodTest")
  149. .def(py::init<>())
  150. .def("triple", [](CppBoundMethodTest &, int val) { return 3 * val; });
  151. });