Skip to content

Commit

Permalink
Fix gil cross threads between C++ and Python
Browse files Browse the repository at this point in the history
  • Loading branch information
auxten committed May 8, 2024
1 parent 494047b commit c44c2be
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 13 deletions.
95 changes: 84 additions & 11 deletions src/TableFunctions/TableFunctionPython.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#include <Storages/StoragePython.h>
#include <TableFunctions/TableFunctionFactory.h>
#include <TableFunctions/TableFunctionPython.h>
#include <pybind11/gil.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <Poco/Logger.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

Expand All @@ -21,8 +23,73 @@ extern const int PY_OBJECT_NOT_FOUND;
extern const int PY_EXCEPTION_OCCURED;
}

// Helper function to check if an object's class is or inherits from PyReader with a maximum depth
bool is_or_inherits_from_pyreader(const py::handle & obj, int depth = 3)
{
// Base case: if depth limit reached, stop the recursion
if (depth == 0)
return false;

// Check directly if obj is an instance of PyReader
if (py::isinstance(obj, py::module_::import("chdb").attr("PyReader")))
return true;

// Check if obj's class or any of its bases is PyReader
py::object cls = obj.attr("__class__");
if (py::hasattr(cls, "__bases__"))
{
for (auto base : cls.attr("__bases__"))
if (py::str(base.attr("__name__")).cast<std::string>() == "PyReader" || is_or_inherits_from_pyreader(base, depth - 1))
return true;
}
return false;
}

// Function to find instances of PyReader or classes derived from PyReader, filtered by variable name
std::vector<py::object> find_instances_of_pyreader(const std::string & var_name)
{
std::vector<py::object> instances;

// Access the main module and its global dictionary
py::dict globals = py::reinterpret_borrow<py::dict>(py::module_::import("__main__").attr("__dict__"));

// Search in global scope
if (globals.contains(var_name))
{
py::object obj = globals[var_name.data()];
if (py::isinstance<py::object>(obj) && py::hasattr(obj, "__class__"))
{
if (is_or_inherits_from_pyreader(obj))
instances.push_back(obj);
}
}
if (!instances.empty())
return instances;

// Check objects in the garbage collector if nothing found, filtering by var_name
// typically used to find objects that are not in the global scope, like in functions
LOG_DEBUG(&Poco::Logger::get("TableFunctionPython"), "Searching for PyReader objects in the garbage collector");
py::module_ gc = py::module_::import("gc");
py::list all_objects = gc.attr("get_objects")();

for (auto obj : all_objects)
{
if (py::isinstance<py::object>(obj) && py::hasattr(obj, "__class__"))
{
if (is_or_inherits_from_pyreader(obj) && py::str(obj.attr("__class__").attr("__name__")).cast<std::string>() == var_name)
{
if (std::find(instances.begin(), instances.end(), obj) == instances.end())
instances.push_back(obj.cast<py::object>());
}
}
}

return instances;
}

void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr context)
{
py::gil_scoped_acquire acquire;
const auto & func_args = ast_function->as<ASTFunction &>();

if (!func_args.arguments)
Expand All @@ -37,10 +104,10 @@ void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr

try
{
py::dict global_vars = py::globals();
LOG_DEBUG(logger, "Globals content: {}", String(py::str(global_vars)));
py::dict main_vars = py::reinterpret_borrow<py::dict>(py::module_::import("__main__").attr("__dict__").ptr());
LOG_DEBUG(logger, "Main content: {}", String(py::str(main_vars)));
// py::dict global_vars = py::globals();
// LOG_DEBUG(logger, "Globals content: {}", String(py::str(global_vars)));
// py::dict main_vars = py::reinterpret_borrow<py::dict>(py::module_::import("__main__").attr("__dict__").ptr());
// LOG_DEBUG(logger, "Main content: {}", String(py::str(main_vars)));

// get the py_reader_arg without quotes
auto py_reader_arg_str = py_reader_arg->as<ASTLiteral &>().value.safeGet<String>();
Expand All @@ -51,14 +118,19 @@ void TableFunctionPython::parseArguments(const ASTPtr & ast_function, ContextPtr
std::remove_if(py_reader_arg_str.begin(), py_reader_arg_str.end(), [](char c) { return c == '\'' || c == '\"' || c == '`'; }),
py_reader_arg_str.end());

// try global_vars first, if not found, try main_vars
py::object obj_by_name
= global_vars.contains(py_reader_arg_str.data()) ? global_vars[py_reader_arg_str.data()] : main_vars[py_reader_arg_str.data()];
auto instances = find_instances_of_pyreader(py_reader_arg_str);
if (instances.empty())
throw Exception(ErrorCodes::PY_OBJECT_NOT_FOUND, "PyReader object not found in the Python environment");
if (instances.size() > 1)
throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Multiple PyReader objects found in the Python environment");

LOG_DEBUG(
logger,
"PyReader object found in Python environment with name: {} type: {}",
py_reader_arg_str,
py::str(instances[0].attr("__class__")).cast<std::string>());

// check if obj_by_name is a PyReader object or a subclass object of PyReader
if (!py::isinstance<PyReader>(obj_by_name))
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Python object is not a PyReader object");
reader = std::dynamic_pointer_cast<PyReader>(obj_by_name.cast<std::shared_ptr<PyReader>>());
reader = instances[0];
}
catch (py::error_already_set & e)
{
Expand All @@ -73,6 +145,7 @@ StoragePtr TableFunctionPython::executeImpl(
ColumnsDescription /*cached_columns*/,
bool is_insert_query) const
{
py::gil_scoped_acquire acquire;
if (!reader)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Python reader not initialized");

Expand Down
12 changes: 10 additions & 2 deletions src/TableFunctions/TableFunctionPython.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#pragma once

#include <Storages/ColumnsDescription.h>
#include <Storages/StoragePython.h>
#include <TableFunctions/ITableFunction.h>
#include <pybind11/pytypes.h>
#include <Poco/Logger.h>
#include "Storages/ColumnsDescription.h"

namespace DB
{
Expand All @@ -13,6 +14,13 @@ class TableFunctionPython : public ITableFunction
public:
static constexpr auto name = "python";
std::string getName() const override { return name; }
~TableFunctionPython() override
{
    // The reference held by `reader` is dropped here, explicitly, while the GIL is held.
    py::gil_scoped_acquire gil_guard;

    // Detach the pointer from `reader` (leaving it empty, so the member's own
    // destructor has nothing left to decrement) and drop the single reference it owned.
    py::handle detached = reader.release();
    detached.dec_ref();
}

private:
Poco::Logger * logger = &Poco::Logger::get("TableFunctionPython");
Expand All @@ -27,7 +35,7 @@ class TableFunctionPython : public ITableFunction
void parseArguments(const ASTPtr & ast_function, ContextPtr context) override;

ColumnsDescription getActualTableStructure(ContextPtr context, bool is_insert_query) const override;
std::shared_ptr<PyReader> reader;
py::object reader;
};

}

0 comments on commit c44c2be

Please sign in to comment.