diff --git a/src/Common/PythonUtils.h b/src/Common/PythonUtils.h
index ac9bf1cf5d3..8c927bb8f8f 100644
--- a/src/Common/PythonUtils.h
+++ b/src/Common/PythonUtils.h
@@ -3,6 +3,7 @@
 #include
 #include
 // #include
+#include
 #include
 #include
 #include
@@ -14,6 +15,7 @@
 #include
 #include
+
 
 namespace DB
 {
@@ -25,6 +27,23 @@ extern const int NOT_IMPLEMENTED;
 
 namespace py = pybind11;
 
+struct ColumnWrapper
+{
+    void * buf; // we may modify the data when cast it to PyObject **, so we need a non-const pointer
+    size_t row_count;
+    py::handle data;
+    DataTypePtr dest_type;
+    std::string py_type; //py::handle type, eg. numpy.ndarray;
+    std::string row_format;
+    std::string encoding; // utf8, utf16, utf32, etc.
+    std::string name;
+};
+
+using PyObjectVec = std::vector<PyObject *>;
+using PyObjectVecPtr = std::shared_ptr<PyObjectVec>;
+using PyColumnVec = std::vector<ColumnWrapper>;
+using PyColumnVecPtr = std::shared_ptr<PyColumnVec>;
+
 // Template wrapper function to handle any return type
 template <typename Func, typename... Args>
 auto execWithGIL(Func func, Args &&... args) -> decltype(func(std::forward<Args>(args)...))
diff --git a/src/Processors/Sources/PythonSource.cpp b/src/Processors/Sources/PythonSource.cpp
index 72cdd4bfe0a..4139ec04e4a 100644
--- a/src/Processors/Sources/PythonSource.cpp
+++ b/src/Processors/Sources/PythonSource.cpp
@@ -42,11 +42,16 @@ extern const int PY_EXCEPTION_OCCURED;
 PythonSource::PythonSource(
     py::object & data_source_,
     const Block & sample_block_,
-    const UInt64 max_block_size_,
-    const size_t stream_index,
-    const size_t num_streams)
+    PyColumnVecPtr column_cache,
+    size_t data_source_row_count,
+    size_t max_block_size_,
+    size_t stream_index,
+    size_t num_streams)
     : ISource(sample_block_.cloneEmpty())
     , data_source(data_source_)
+    , sample_block(sample_block_)
+    , column_cache(column_cache)
+    , data_source_row_count(data_source_row_count)
     , max_block_size(max_block_size_)
     , stream_index(stream_index)
     , num_streams(num_streams)
@@ -261,6 +266,8 @@ PythonSource::scanData(const py::object & data, const std::vector<std::string> &
     return std::move(block);
 }
 
+
+
 Chunk PythonSource::scanDataToChunk()
 {
     auto names = description.sample_block.getNames();
@@ -276,37 +283,6 @@ Chunk PythonSource::scanDataToChunk()
     if (names.size() != columns.size())
         throw Exception(ErrorCodes::PY_EXCEPTION_OCCURED, "Column cache size mismatch");
 
-    {
-        // check column cache with GIL holded
-        py::gil_scoped_acquire acquire;
-        if (column_cache == nullptr)
-        {
-            // fill in the cache
-            column_cache = std::make_shared<PyColumnVec>(columns.size());
-            for (size_t i = 0; i < columns.size(); ++i)
-            {
-                const auto & col_name = names[i];
-                auto & col = (*column_cache)[i];
-                col.name = col_name;
-                try
-                {
-                    py::object col_data = data_source[py::str(col_name)];
-                    col.buf = const_cast<void *>(tryGetPyArray(col_data, col.data, col.py_type, col.row_count));
-                    if (col.buf == nullptr)
-                        throw Exception(
-                            ErrorCodes::PY_EXCEPTION_OCCURED, "Convert to array failed for column {} type {}", col_name, col.py_type);
-                    col.dest_type = description.sample_block.getByPosition(i).type;
-                    data_source_row_count = col.row_count;
-                }
-                catch (const Exception & e)
-                {
-                    LOG_ERROR(logger, "Error processing column {}: {}", col_name, e.what());
-                    throw;
-                }
-            }
-        }
-    }
-
     auto rows_per_stream = data_source_row_count / num_streams;
     auto start = stream_index * rows_per_stream;
     auto end = (stream_index + 1) * rows_per_stream;
diff --git a/src/Processors/Sources/PythonSource.h b/src/Processors/Sources/PythonSource.h
index 7d23eea2c40..23fd8fcc3ff 100644
--- a/src/Processors/Sources/PythonSource.h
+++ b/src/Processors/Sources/PythonSource.h
@@ -8,7 +8,7 @@
 #include
 #include
 #include
-#include "DataTypes/IDataType.h"
+#include <DataTypes/IDataType.h>
 
 namespace DB
 {
@@ -17,28 +17,18 @@
 namespace py = pybind11;
 
 class PyReader;
 
-struct ColumnWrapper
-{
-    void * buf; // we may modify the data when cast it to PyObject **, so we need a non-const pointer
-    size_t row_count;
-    py::handle data;
-    DataTypePtr dest_type;
-    std::string py_type; //py::handle type, eg. numpy.ndarray;
-    std::string row_format;
-    std::string encoding; // utf8, utf16, utf32, etc.
-    std::string name;
-};
-
-using PyObjectVec = std::vector<PyObject *>;
-using PyObjectVecPtr = std::shared_ptr<PyObjectVec>;
-using PyColumnVec = std::vector<ColumnWrapper>;
-using PyColumnVecPtr = std::shared_ptr<PyColumnVec>;
-
 class PythonSource : public ISource
 {
 public:
-    PythonSource(py::object & data_source_, const Block & sample_block_, UInt64 max_block_size_, size_t stream_index, size_t num_streams);
+    PythonSource(
+        py::object & data_source_,
+        const Block & sample_block_,
+        PyColumnVecPtr column_cache,
+        size_t data_source_row_count,
+        size_t max_block_size_,
+        size_t stream_index,
+        size_t num_streams);
 
     ~PythonSource() override = default;
@@ -52,18 +42,19 @@ class PythonSource : public ISource
 
     Block sample_block;
     PyColumnVecPtr column_cache;
-
+    size_t data_source_row_count;
     const UInt64 max_block_size;
     // Caller will only pass stream index and total stream count
     // to the constructor, we need to calculate the start offset and end offset.
     const size_t stream_index;
     const size_t num_streams;
     size_t cursor;
-    size_t data_source_row_count;
+
     Poco::Logger * logger = &Poco::Logger::get("TableFunctionPython");
     ExternalResultDescription description;
 
     PyObjectVecPtr scanData(const py::object & data, const std::vector<std::string> & col_names, size_t & cursor, size_t count);
+    void prepareColumnCache(Names & names, Columns & columns);
     Chunk scanDataToChunk();
     void destory(PyObjectVecPtr & data);
 };
diff --git a/src/Storages/StoragePython.cpp b/src/Storages/StoragePython.cpp
index c8a7e11d7fc..9168d59e0db 100644
--- a/src/Storages/StoragePython.cpp
+++ b/src/Storages/StoragePython.cpp
@@ -63,29 +63,19 @@
     storage_snapshot->check(column_names);
 
     Block sample_block = prepareSampleBlock(column_names, storage_snapshot);
-    // check if string type column involved
-    bool has_string_column = false;
-    for (const auto & column_name : column_names)
-    {
-        if (sample_block.getByName(column_name).type->getName() == "String")
-        {
-            has_string_column = true;
-            break;
-        }
-    }
 
     // num_streams = 3; // for testing
 
-    // Converting Python str to ClickHouse String type will cost a lot of time.
-    // so if string column involved and not using PyReader return multiple streams.
-    if (has_string_column && !isInheritsFromPyReader(data_source))
-    {
-        Pipes pipes;
-        for (size_t stream = 0; stream < num_streams; ++stream)
-            pipes.emplace_back(std::make_shared<PythonSource>(data_source, sample_block, max_block_size, stream, num_streams));
-        return Pipe::unitePipes(std::move(pipes));
-    }
-    return Pipe(std::make_shared<PythonSource>(data_source, sample_block, max_block_size, 0, 1));
+    prepareColumnCache(column_names, sample_block.getColumns(), sample_block);
+
+    if (isInheritsFromPyReader(data_source))
+        return Pipe(std::make_shared<PythonSource>(data_source, sample_block, column_cache, data_source_row_count, max_block_size, 0, 1));
+
+    Pipes pipes;
+    for (size_t stream = 0; stream < num_streams; ++stream)
+        pipes.emplace_back(std::make_shared<PythonSource>(
+            data_source, sample_block, column_cache, data_source_row_count, max_block_size, stream, num_streams));
+    return Pipe::unitePipes(std::move(pipes));
 }
 
 Block StoragePython::prepareSampleBlock(const Names & column_names, const StorageSnapshotPtr & storage_snapshot)
@@ -99,6 +89,38 @@ Block StoragePython::prepareSampleBlock(const Names & column_names, const Storag
     return sample_block;
 }
 
+void StoragePython::prepareColumnCache(const Names & names, const Columns & columns, const Block & sample_block)
+{
+    // check column cache with the GIL held
+    py::gil_scoped_acquire acquire;
+    if (column_cache == nullptr)
+    {
+        // fill in the cache
+        column_cache = std::make_shared<PyColumnVec>(columns.size());
+        for (size_t i = 0; i < columns.size(); ++i)
+        {
+            const auto & col_name = names[i];
+            auto & col = (*column_cache)[i];
+            col.name = col_name;
+            try
+            {
+                py::object col_data = data_source[py::str(col_name)];
+                col.buf = const_cast<void *>(tryGetPyArray(col_data, col.data, col.py_type, col.row_count));
+                if (col.buf == nullptr)
+                    throw Exception(
+                        ErrorCodes::PY_EXCEPTION_OCCURED, "Convert to array failed for column {} type {}", col_name, col.py_type);
+                col.dest_type = sample_block.getByPosition(i).type;
+                data_source_row_count = col.row_count;
+            }
+            catch (const Exception & e)
+            {
+                LOG_ERROR(logger, "Error processing column {}: {}", col_name, e.what());
+                throw;
+            }
+        }
+    }
+}
+
 ColumnsDescription StoragePython::getTableStructureFromData(py::object data_source)
 {
     if (!data_source)
diff --git a/src/Storages/StoragePython.h b/src/Storages/StoragePython.h
index b8288c297b3..219171fddd1 100644
--- a/src/Storages/StoragePython.h
+++ b/src/Storages/StoragePython.h
@@ -14,7 +14,7 @@
 #include
 #include
 #include
-#include "object.h"
+#include
 
 
 namespace DB
@@ -26,6 +26,7 @@ namespace ErrorCodes
 {
 extern const int UNKNOWN_FORMAT;
 extern const int NOT_IMPLEMENTED;
+extern const int PY_EXCEPTION_OCCURED;
 }
 class PyReader
 {
@@ -169,7 +170,10 @@ class StoragePython : public IStorage, public WithContext
     static ColumnsDescription getTableStructureFromData(py::object data_source);
 
 private:
+    void prepareColumnCache(const Names & names, const Columns & columns, const Block & sample_block);
     py::object data_source;
+    PyColumnVecPtr column_cache;
+    size_t data_source_row_count;
    Poco::Logger * logger = &Poco::Logger::get("StoragePython");
 };