You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							218 lines
						
					
					
						
							7.5 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							218 lines
						
					
					
						
							7.5 KiB
						
					
					
				
								#include <pybind11/embed.h>
							 | 
						|
								#include <catch.hpp>
							 | 
						|
								
							 | 
						|
								#include <thread>
							 | 
						|
								
							 | 
						|
								namespace py = pybind11;
							 | 
						|
								using namespace py::literals;
							 | 
						|
								
							 | 
						|
								class Widget {
							 | 
						|
								public:
							 | 
						|
								    Widget(std::string message) : message(message) { }
							 | 
						|
								    virtual ~Widget() = default;
							 | 
						|
								
							 | 
						|
								    std::string the_message() const { return message; }
							 | 
						|
								    virtual int the_answer() const = 0;
							 | 
						|
								
							 | 
						|
								private:
							 | 
						|
								    std::string message;
							 | 
						|
								};
							 | 
						|
								
							 | 
						|
								class PyWidget final : public Widget {
							 | 
						|
								    using Widget::Widget;
							 | 
						|
								
							 | 
						|
								    int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); }
							 | 
						|
								};
							 | 
						|
								
							 | 
						|
								PYBIND11_EMBEDDED_MODULE(widget_module, m) {
							 | 
						|
								    py::class_<Widget, PyWidget>(m, "Widget")
							 | 
						|
								        .def(py::init<std::string>())
							 | 
						|
								        .def_property_readonly("the_message", &Widget::the_message);
							 | 
						|
								
							 | 
						|
								    m.def("add", [](int i, int j) { return i + j; });
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
							 | 
						|
								    throw std::runtime_error("C++ Error");
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
							 | 
						|
								    auto d = py::dict();
							 | 
						|
								    d["missing"].cast<py::object>();
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
							 | 
						|
								    auto module = py::module::import("test_interpreter");
							 | 
						|
								    REQUIRE(py::hasattr(module, "DerivedWidget"));
							 | 
						|
								
							 | 
						|
								    auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__"));
							 | 
						|
								    py::exec(R"(
							 | 
						|
								        widget = DerivedWidget("{} - {}".format(hello, x))
							 | 
						|
								        message = widget.the_message
							 | 
						|
								    )", py::globals(), locals);
							 | 
						|
								    REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
							 | 
						|
								
							 | 
						|
								    auto py_widget = module.attr("DerivedWidget")("The question");
							 | 
						|
								    auto message = py_widget.attr("the_message");
							 | 
						|
								    REQUIRE(message.cast<std::string>() == "The question");
							 | 
						|
								
							 | 
						|
								    const auto &cpp_widget = py_widget.cast<const Widget &>();
							 | 
						|
								    REQUIRE(cpp_widget.the_answer() == 42);
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("Import error handling") {
							 | 
						|
								    REQUIRE_NOTHROW(py::module::import("widget_module"));
							 | 
						|
								    REQUIRE_THROWS_WITH(py::module::import("throw_exception"),
							 | 
						|
								                        "ImportError: C++ Error");
							 | 
						|
								    REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"),
							 | 
						|
								                        Catch::Contains("ImportError: KeyError"));
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("There can be only one interpreter") {
							 | 
						|
								    static_assert(std::is_move_constructible<py::scoped_interpreter>::value, "");
							 | 
						|
								    static_assert(!std::is_move_assignable<py::scoped_interpreter>::value, "");
							 | 
						|
								    static_assert(!std::is_copy_constructible<py::scoped_interpreter>::value, "");
							 | 
						|
								    static_assert(!std::is_copy_assignable<py::scoped_interpreter>::value, "");
							 | 
						|
								
							 | 
						|
								    REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
							 | 
						|
								    REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
							 | 
						|
								
							 | 
						|
								    py::finalize_interpreter();
							 | 
						|
								    REQUIRE_NOTHROW(py::scoped_interpreter());
							 | 
						|
								    {
							 | 
						|
								        auto pyi1 = py::scoped_interpreter();
							 | 
						|
								        auto pyi2 = std::move(pyi1);
							 | 
						|
								    }
							 | 
						|
								    py::initialize_interpreter();
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								bool has_pybind11_internals_builtin() {
							 | 
						|
								    auto builtins = py::handle(PyEval_GetBuiltins());
							 | 
						|
								    return builtins.contains(PYBIND11_INTERNALS_ID);
							 | 
						|
								};
							 | 
						|
								
							 | 
						|
								bool has_pybind11_internals_static() {
							 | 
						|
								    return py::detail::get_internals_ptr() != nullptr;
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("Restart the interpreter") {
							 | 
						|
								    // Verify pre-restart state.
							 | 
						|
								    REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
							 | 
						|
								    REQUIRE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								    // Restart the interpreter.
							 | 
						|
								    py::finalize_interpreter();
							 | 
						|
								    REQUIRE(Py_IsInitialized() == 0);
							 | 
						|
								
							 | 
						|
								    py::initialize_interpreter();
							 | 
						|
								    REQUIRE(Py_IsInitialized() == 1);
							 | 
						|
								
							 | 
						|
								    // Internals are deleted after a restart.
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_static());
							 | 
						|
								    pybind11::detail::get_internals();
							 | 
						|
								    REQUIRE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								    // Make sure that an interpreter with no get_internals() created until finalize still gets the
							 | 
						|
								    // internals destroyed
							 | 
						|
								    py::finalize_interpreter();
							 | 
						|
								    py::initialize_interpreter();
							 | 
						|
								    bool ran = false;
							 | 
						|
								    py::module::import("__main__").attr("internals_destroy_test") =
							 | 
						|
								        py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_static());
							 | 
						|
								    REQUIRE_FALSE(ran);
							 | 
						|
								    py::finalize_interpreter();
							 | 
						|
								    REQUIRE(ran);
							 | 
						|
								    py::initialize_interpreter();
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								    // C++ modules can be reloaded.
							 | 
						|
								    auto cpp_module = py::module::import("widget_module");
							 | 
						|
								    REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
							 | 
						|
								
							 | 
						|
								    // C++ type information is reloaded and can be used in python modules.
							 | 
						|
								    auto py_module = py::module::import("test_interpreter");
							 | 
						|
								    auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
							 | 
						|
								    REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("Subinterpreter") {
							 | 
						|
								    // Add tags to the modules in the main interpreter and test the basics.
							 | 
						|
								    py::module::import("__main__").attr("main_tag") = "main interpreter";
							 | 
						|
								    {
							 | 
						|
								        auto m = py::module::import("widget_module");
							 | 
						|
								        m.attr("extension_module_tag") = "added to module in main interpreter";
							 | 
						|
								
							 | 
						|
								        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
							 | 
						|
								    }
							 | 
						|
								    REQUIRE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								    /// Create and switch to a subinterpreter.
							 | 
						|
								    auto main_tstate = PyThreadState_Get();
							 | 
						|
								    auto sub_tstate = Py_NewInterpreter();
							 | 
						|
								
							 | 
						|
								    // Subinterpreters get their own copy of builtins. detail::get_internals() still
							 | 
						|
								    // works by returning from the static variable, i.e. all interpreters share a single
							 | 
						|
								    // global pybind11::internals;
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_builtin());
							 | 
						|
								    REQUIRE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								    // Modules tags should be gone.
							 | 
						|
								    REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag"));
							 | 
						|
								    {
							 | 
						|
								        auto m = py::module::import("widget_module");
							 | 
						|
								        REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
							 | 
						|
								
							 | 
						|
								        // Function bindings should still work.
							 | 
						|
								        REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
							 | 
						|
								    }
							 | 
						|
								
							 | 
						|
								    // Restore main interpreter.
							 | 
						|
								    Py_EndInterpreter(sub_tstate);
							 | 
						|
								    PyThreadState_Swap(main_tstate);
							 | 
						|
								
							 | 
						|
								    REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag"));
							 | 
						|
								    REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag"));
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("Execution frame") {
							 | 
						|
								    // When the interpreter is embedded, there is no execution frame, but `py::exec`
							 | 
						|
								    // should still function by using reasonable globals: `__main__.__dict__`.
							 | 
						|
								    py::exec("var = dict(number=42)");
							 | 
						|
								    REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								TEST_CASE("Threads") {
							 | 
						|
								    // Restart interpreter to ensure threads are not initialized
							 | 
						|
								    py::finalize_interpreter();
							 | 
						|
								    py::initialize_interpreter();
							 | 
						|
								    REQUIRE_FALSE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								    constexpr auto num_threads = 10;
							 | 
						|
								    auto locals = py::dict("count"_a=0);
							 | 
						|
								
							 | 
						|
								    {
							 | 
						|
								        py::gil_scoped_release gil_release{};
							 | 
						|
								        REQUIRE(has_pybind11_internals_static());
							 | 
						|
								
							 | 
						|
								        auto threads = std::vector<std::thread>();
							 | 
						|
								        for (auto i = 0; i < num_threads; ++i) {
							 | 
						|
								            threads.emplace_back([&]() {
							 | 
						|
								                py::gil_scoped_acquire gil{};
							 | 
						|
								                locals["count"] = locals["count"].cast<int>() + 1;
							 | 
						|
								            });
							 | 
						|
								        }
							 | 
						|
								
							 | 
						|
								        for (auto &thread : threads) {
							 | 
						|
								            thread.join();
							 | 
						|
								        }
							 | 
						|
								    }
							 | 
						|
								
							 | 
						|
								    REQUIRE(locals["count"].cast<int>() == num_threads);
							 | 
						|
								}
							 |