Skip to content

Commit c44c2be

Browse files
committed
Fix GIL handling across threads between C++ and Python
1 parent 494047b commit c44c2be

File tree

2 files changed

+94
-13
lines changed

2 files changed

+94
-13
lines changed

src/TableFunctions/TableFunctionPython.cpp

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
#include <Storages/StoragePython.h>
77
#include <TableFunctions/TableFunctionFactory.h>
88
#include <TableFunctions/TableFunctionPython.h>
9+
#include <pybind11/gil.h>
910
#include <pybind11/pybind11.h>
1011
#include <pybind11/pytypes.h>
12+
#include <Poco/Logger.h>
1113
#include <Common/Exception.h>
1214
#include <Common/logger_useful.h>
1315

@@ -21,8 +23,73 @@ extern const int PY_OBJECT_NOT_FOUND;
2123
extern const int PY_EXCEPTION_OCCURED;
2224
}
2325

26+
// Helper function to check if an object's class is or inherits from PyReader with a maximum depth
27+
bool is_or_inherits_from_pyreader(const py::handle & obj, int depth = 3)
28+
{
29+
// Base case: if depth limit reached, stop the recursion
30+
if (depth == 0)
31+
return false;
32+
33+
// Check directly if obj is an instance of PyReader
34+
if (py::isinstance(obj, py::module_::import("chdb").attr("PyReader")))
35+
return true;
36+
37+
// Check if obj's class or any of its bases is PyReader
38+
py::object cls = obj.attr("__class__");
39+
if (py::hasattr(cls, "__bases__"))
40+
{
41+
for (auto base : cls.attr("__bases__"))
42+
if (py::str(base.attr("__name__")).cast<std::string>() == "PyReader" || is_or_inherits_from_pyreader(base, depth - 1))
43+
return true;
44+
}
45+
return false;
46+
}
47+
48+
// Function to find instances of PyReader or classes derived from PyReader, filtered by variable name
49+
std::vector<py::object> find_instances_of_pyreader(const std::string & var_name)
50+
{
51+
std::vector<py::object> instances;
52+
53+
// Access the main module and its global dictionary
54+
py::dict globals = py::reinterpret_borrow<py::dict>(py::module_::import("__main__").attr("__dict__"));
55+
56+
// Search in global scope
57+
if (globals.contains(var_name))
58+
{
59+
py::object obj = globals[var_name.data()];
60+
if (py::isinstance<py::object>(obj) && py::hasattr(obj, "__class__"))
61+
{
62+
if (is_or_inherits_from_pyreader(obj))
63+
instances.push_back(obj);
64+
}
65+
}
66+
if (!instances.empty())
67+
return instances;
68+
69+
// Check objects in the garbage collector if nothing found, filtering by var_name
70+
// typically used to find objects that are not in the global scope, like in functions
71+
LOG_DEBUG(&Poco::Logger::get("TableFunctionPython"), "Searching for PyReader objects in the garbage collector");
72+
py::module_ gc = py::module_::import("gc");
73+
py::list all_objects = gc.attr("get_objects")();
74+
75+
for (auto obj : all_objects)
76+
{
77+
if (py::isinstance<py::object>(obj) && py::hasattr(obj, "__class__"))
78+
{
79+
if (is_or_inherits_from_pyreader(obj) && py::str(obj.attr("__class__").attr("__name__")).cast<std::string>() == var_name)
80+
{
81+
if (std::find(instances.begin(), instances.end(), obj) == instances.end())
82+
instances.push_back(obj.cast<py::object>());
83+
}
84+
}
85+
}
86+
87+
return instances;
88+
}
89+
2490
void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr context)
2591
{
92+
py::gil_scoped_acquire acquire;
2693
const auto & func_args = ast_function->as<ASTFunction &>();
2794

2895
if (!func_args.arguments)
@@ -37,10 +104,10 @@ void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr
37104

38105
try
39106
{
40-
py::dict global_vars = py::globals();
41-
LOG_DEBUG(logger, "Globals content: {}", String(py::str(global_vars)));
42-
py::dict main_vars = py::reinterpret_borrow<py::dict>(py::module_::import("__main__").attr("__dict__").ptr());
43-
LOG_DEBUG(logger, "Main content: {}", String(py::str(main_vars)));
107+
// py::dict global_vars = py::globals();
108+
// LOG_DEBUG(logger, "Globals content: {}", String(py::str(global_vars)));
109+
// py::dict main_vars = py::reinterpret_borrow<py::dict>(py::module_::import("__main__").attr("__dict__").ptr());
110+
// LOG_DEBUG(logger, "Main content: {}", String(py::str(main_vars)));
44111

45112
// get the py_reader_arg without quotes
46113
auto py_reader_arg_str = py_reader_arg->as<ASTLiteral &>().value.safeGet<String>();
@@ -51,14 +118,19 @@ void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr
51118
std::remove_if(py_reader_arg_str.begin(), py_reader_arg_str.end(), [](char c) { return c == '\'' || c == '\"' || c == '`'; }),
52119
py_reader_arg_str.end());
53120

54-
// try global_vars first, if not found, try main_vars
55-
py::object obj_by_name
56-
= global_vars.contains(py_reader_arg_str.data()) ? global_vars[py_reader_arg_str.data()] : main_vars[py_reader_arg_str.data()];
121+
auto instances = find_instances_of_pyreader(py_reader_arg_str);
122+
if (instances.empty())
123+
throw Exception(ErrorCodes::PY_OBJECT_NOT_FOUND, "PyReader object not found in the Python environment");
124+
if (instances.size() > 1)
125+
throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Multiple PyReader objects found in the Python environment");
126+
127+
LOG_DEBUG(
128+
logger,
129+
"PyReader object found in Python environment with name: {} type: {}",
130+
py_reader_arg_str,
131+
py::str(instances[0].attr("__class__")).cast<std::string>());
57132

58-
// check if obj_by_name is a PyReader object or a subclass object of PyReader
59-
if (!py::isinstance<PyReader>(obj_by_name))
60-
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Python object is not a PyReader object");
61-
reader = std::dynamic_pointer_cast<PyReader>(obj_by_name.cast<std::shared_ptr<PyReader>>());
133+
reader = instances[0];
62134
}
63135
catch (py::error_already_set & e)
64136
{
@@ -73,6 +145,7 @@ StoragePtr TableFunctionPython::executeImpl(
73145
ColumnsDescription /*cached_columns*/,
74146
bool is_insert_query) const
75147
{
148+
py::gil_scoped_acquire acquire;
76149
if (!reader)
77150
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Python reader not initialized");
78151

src/TableFunctions/TableFunctionPython.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#pragma once
22

3+
#include <Storages/ColumnsDescription.h>
34
#include <Storages/StoragePython.h>
45
#include <TableFunctions/ITableFunction.h>
6+
#include <pybind11/pytypes.h>
57
#include <Poco/Logger.h>
6-
#include "Storages/ColumnsDescription.h"
78

89
namespace DB
910
{
@@ -13,6 +14,13 @@ class TableFunctionPython : public ITableFunction
1314
public:
1415
static constexpr auto name = "python";
1516
std::string getName() const override { return name; }
17+
~TableFunctionPython() override
{
    // The reader is a Python object, so its reference must be dropped while holding the GIL.
    // release() detaches the pointer from `reader`, so the member's own destructor has
    // nothing left to decrement afterwards.
    py::gil_scoped_acquire gil;
    reader.release().dec_ref();
}
1624

1725
private:
1826
Poco::Logger * logger = &Poco::Logger::get("TableFunctionPython");
@@ -27,7 +35,7 @@ class TableFunctionPython : public ITableFunction
2735
void parseArguments(const ASTPtr & ast_function, ContextPtr context) override;
2836

2937
ColumnsDescription getActualTableStructure(ContextPtr context, bool is_insert_query) const override;
30-
std::shared_ptr<PyReader> reader;
38+
py::object reader;
3139
};
3240

3341
}

0 commit comments

Comments
 (0)