You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

218 lines
7.5 KiB

  // pybind11/embed.h pulls in Python.h, which must be included before standard headers.
  #include <pybind11/embed.h>
  #include <catch.hpp>

  #include <string>
  #include <thread>
  #include <type_traits>
  #include <utility>
  #include <vector>

  namespace py = pybind11;
  using namespace py::literals;
  /// Abstract C++ base class that is exposed to Python as `widget_module.Widget`.
  ///
  /// Holds a message fixed at construction time and leaves `the_answer()` to be
  /// implemented by subclasses (in C++, or in Python through `PyWidget`).
  class Widget {
  public:
      /// @param message  Text later returned by `the_message()`. Taken by value and
      ///                 moved into the member, so callers passing an rvalue pay no copy.
      explicit Widget(std::string message) : message(std::move(message)) { }
      virtual ~Widget() = default;

      /// @return A copy of the message given to the constructor.
      std::string the_message() const { return message; }

      /// Pure virtual; concrete subclasses supply the value.
      virtual int the_answer() const = 0;

  private:
      std::string message;
  };
  15. class PyWidget final : public Widget {
  16. using Widget::Widget;
  17. int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
  18. };
  19. PYBIND11_EMBEDDED_MODULE(widget_module, m) {
  20. py::class_<Widget, PyWidget>(m, "Widget")
  21. .def(py::init<std::string>())
  22. .def_property_readonly("the_message", &Widget::the_message);
  23. m.def("add", [](int i, int j) { return i + j; });
  24. }
  25. PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
  26. throw std::runtime_error("C++ Error");
  27. }
  28. PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
  29. auto d = py::dict();
  30. d["missing"].cast<py::object>();
  31. }
  32. TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
  33. auto module = py::module::import("test_interpreter");
  34. REQUIRE(py::hasattr(module, "DerivedWidget"));
  35. auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__"));
  36. py::exec(R"(
  37. widget = DerivedWidget("{} - {}".format(hello, x))
  38. message = widget.the_message
  39. )", py::globals(), locals);
  40. REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
  41. auto py_widget = module.attr("DerivedWidget")("The question");
  42. auto message = py_widget.attr("the_message");
  43. REQUIRE(message.cast<std::string>() == "The question");
  44. const auto &cpp_widget = py_widget.cast<const Widget &>();
  45. REQUIRE(cpp_widget.the_answer() == 42);
  46. }
  47. TEST_CASE("Import error handling") {
  48. REQUIRE_NOTHROW(py::module::import("widget_module"));
  49. REQUIRE_THROWS_WITH(py::module::import("throw_exception"),
  50. "ImportError: C++ Error");
  51. REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"),
  52. Catch::Contains("ImportError: KeyError"));
  53. }
  54. TEST_CASE("There can be only one interpreter") {
  55. static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
  56. static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
  57. static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
  58. static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
  59. REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
  60. REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
  61. py::finalize_interpreter();
  62. REQUIRE_NOTHROW(py::scoped_interpreter());
  63. {
  64. auto pyi1 = py::scoped_interpreter();
  65. auto pyi2 = std::move(pyi1);
  66. }
  67. py::initialize_interpreter();
  68. }
  69. bool has_pybind11_internals_builtin() {
  70. auto builtins = py::handle(PyEval_GetBuiltins());
  71. return builtins.contains(PYBIND11_INTERNALS_ID);
  72. };
  73. bool has_pybind11_internals_static() {
  74. return py::detail::get_internals_ptr() != nullptr;
  75. }
  76. TEST_CASE("Restart the interpreter") {
  77. // Verify pre-restart state.
  78. REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
  79. REQUIRE(has_pybind11_internals_builtin());
  80. REQUIRE(has_pybind11_internals_static());
  81. // Restart the interpreter.
  82. py::finalize_interpreter();
  83. REQUIRE(Py_IsInitialized() == 0);
  84. py::initialize_interpreter();
  85. REQUIRE(Py_IsInitialized() == 1);
  86. // Internals are deleted after a restart.
  87. REQUIRE_FALSE(has_pybind11_internals_builtin());
  88. REQUIRE_FALSE(has_pybind11_internals_static());
  89. pybind11::detail::get_internals();
  90. REQUIRE(has_pybind11_internals_builtin());
  91. REQUIRE(has_pybind11_internals_static());
  92. // Make sure that an interpreter with no get_internals() created until finalize still gets the
  93. // internals destroyed
  94. py::finalize_interpreter();
  95. py::initialize_interpreter();
  96. bool ran = false;
  97. py::module::import("__main__").attr("internals_destroy_test") =
  98. py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
  99. REQUIRE_FALSE(has_pybind11_internals_builtin());
  100. REQUIRE_FALSE(has_pybind11_internals_static());
  101. REQUIRE_FALSE(ran);
  102. py::finalize_interpreter();
  103. REQUIRE(ran);
  104. py::initialize_interpreter();
  105. REQUIRE_FALSE(has_pybind11_internals_builtin());
  106. REQUIRE_FALSE(has_pybind11_internals_static());
  107. // C++ modules can be reloaded.
  108. auto cpp_module = py::module::import("widget_module");
  109. REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
  110. // C++ type information is reloaded and can be used in python modules.
  111. auto py_module = py::module::import("test_interpreter");
  112. auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
  113. REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
  114. }
  115. TEST_CASE("Subinterpreter") {
  116. // Add tags to the modules in the main interpreter and test the basics.
  117. py::module::import("__main__").attr("main_tag") = "main interpreter";
  118. {
  119. auto m = py::module::import("widget_module");
  120. m.attr("extension_module_tag") = "added to module in main interpreter";
  121. REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
  122. }
  123. REQUIRE(has_pybind11_internals_builtin());
  124. REQUIRE(has_pybind11_internals_static());
  125. /// Create and switch to a subinterpreter.
  126. auto main_tstate = PyThreadState_Get();
  127. auto sub_tstate = Py_NewInterpreter();
  128. // Subinterpreters get their own copy of builtins. detail::get_internals() still
  129. // works by returning from the static variable, i.e. all interpreters share a single
  130. // global pybind11::internals;
  131. REQUIRE_FALSE(has_pybind11_internals_builtin());
  132. REQUIRE(has_pybind11_internals_static());
  133. // Modules tags should be gone.
  134. REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag"));
  135. {
  136. auto m = py::module::import("widget_module");
  137. REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
  138. // Function bindings should still work.
  139. REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
  140. }
  141. // Restore main interpreter.
  142. Py_EndInterpreter(sub_tstate);
  143. PyThreadState_Swap(main_tstate);
  144. REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag"));
  145. REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag"));
  146. }
  147. TEST_CASE("Execution frame") {
  148. // When the interpreter is embedded, there is no execution frame, but `py::exec`
  149. // should still function by using reasonable globals: `__main__.__dict__`.
  150. py::exec("var = dict(number=42)");
  151. REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
  152. }
  153. TEST_CASE("Threads") {
  154. // Restart interpreter to ensure threads are not initialized
  155. py::finalize_interpreter();
  156. py::initialize_interpreter();
  157. REQUIRE_FALSE(has_pybind11_internals_static());
  158. constexpr auto num_threads = 10;
  159. auto locals = py::dict("count"_a=0);
  160. {
  161. py::gil_scoped_release gil_release{};
  162. REQUIRE(has_pybind11_internals_static());
  163. auto threads = std::vector<std::thread>();
  164. for (auto i = 0; i < num_threads; ++i) {
  165. threads.emplace_back([&]() {
  166. py::gil_scoped_acquire gil{};
  167. locals["count"] = locals["count"].cast<int>() + 1;
  168. });
  169. }
  170. for (auto &thread : threads) {
  171. thread.join();
  172. }
  173. }
  174. REQUIRE(locals["count"].cast<int>() == num_threads);
  175. }