Skip to content

Commit

Permalink
Merge all convert_and_insert and getTableStructureFromData v1
Browse files Browse the repository at this point in the history
  • Loading branch information
auxten committed May 7, 2024
1 parent 31d1497 commit 68e47fd
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 56 deletions.
110 changes: 62 additions & 48 deletions src/Processors/Sources/PythonSource.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
#include <Columns/ColumnDecimal.h>
#include <Columns/ColumnString.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeDecimalBase.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <Processors/Sources/PythonSource.h>
#include <Storages/StoragePython.h>
#include <base/Decimal.h>
#include <pybind11/gil.h>
#include <pybind11/pytypes.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>
#include <base/Decimal_fwd.h>
#include <base/types.h>

namespace DB
{
Expand All @@ -18,70 +25,50 @@ PythonSource::PythonSource(std::shared_ptr<PyReader> reader_, const Block & samp
}

template <typename T>
ColumnPtr convert_and_insert(py::object obj)
ColumnPtr convert_and_insert(py::object obj, UInt32 scale = 0)
{
auto column = ColumnVector<T>::create();
// if obj is a list
if (py::isinstance<py::list>(obj))
{
py::list list = obj.cast<py::list>();
for (auto && i : list)
column->insert(i.cast<T>());
// free the list
list.dec_ref();
}
else if (py::isinstance<py::array>(obj)) // if obj is a numpy array
{
py::array array = obj.cast<py::array>();
//chdb: array is a numpy array, so we can directly cast it to a vector?
for (auto && i : array)
column->insert(i.cast<T>());
// free the array, until we implement with zero copy
array.dec_ref();
}
MutableColumnPtr column;
if constexpr (std::is_same_v<T, DateTime64> || std::is_same_v<T, Decimal128> || std::is_same_v<T, Decimal256>)
column = ColumnDecimal<T>::create(0, scale);
else if constexpr (std::is_same_v<T, String>)
column = ColumnString::create();
else
{
throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unsupported type {}", obj.get_type().attr("__name__").cast<std::string>());
}
return column;
}
column = ColumnVector<T>::create();

template <>
ColumnPtr convert_and_insert<String>(py::object obj)
{
auto column = ColumnString::create();
if (py::isinstance<py::list>(obj))
{
py::list list = obj.cast<py::list>();
for (auto && i : list)
column->insert(i.cast<String>());
// free the list
column->insert(i.cast<T>());
list.dec_ref();
}
else if (py::isinstance<py::array>(obj))
{
py::array array = obj.cast<py::array>();
for (auto && i : array)
column->insert(i.cast<String>());
// free the array, until we implement with zero copy
column->insert(i.cast<T>());
array.dec_ref();
}
else
{
throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unsupported type {}", obj.get_type().attr("__name__").cast<std::string>());
throw Exception(
ErrorCodes::BAD_TYPE_OF_FIELD,
"Unsupported type {} for value {}",
obj.get_type().attr("__name__").cast<std::string>(),
py::str(obj).cast<std::string>());
}
return column;
}

Chunk PythonSource::generate()
{
size_t num_rows = 0;

std::vector<py::object> data;
try
{
// GIL is held when called from Python code. Release it to avoid deadlock
py::gil_scoped_release release;
std::vector<py::object> data = reader->read(description.sample_block.getNames(), max_block_size);
data = reader->read(description.sample_block.getNames(), max_block_size);

LOG_DEBUG(logger, "Read {} columns", data.size());
LOG_DEBUG(logger, "Need {} columns", description.sample_block.columns());
Expand Down Expand Up @@ -122,31 +109,58 @@ Chunk PythonSource::generate()
num_rows = py::len(data[i]);
const auto & column = data[i];
const auto & type = description.sample_block.getByPosition(i).type;
WhichDataType which(type);

if (type->equals(*std::make_shared<DataTypeUInt8>()))
if (which.isUInt8())
columns[i] = convert_and_insert<UInt8>(column);
else if (type->equals(*std::make_shared<DataTypeUInt16>()))
else if (which.isUInt16())
columns[i] = convert_and_insert<UInt16>(column);
else if (type->equals(*std::make_shared<DataTypeUInt32>()))
else if (which.isUInt32())
columns[i] = convert_and_insert<UInt32>(column);
else if (type->equals(*std::make_shared<DataTypeUInt64>()))
else if (which.isUInt64())
columns[i] = convert_and_insert<UInt64>(column);
else if (type->equals(*std::make_shared<DataTypeInt8>()))
else if (which.isUInt128())
columns[i] = convert_and_insert<UInt128>(column);
else if (which.isUInt256())
columns[i] = convert_and_insert<UInt256>(column);
else if (which.isInt8())
columns[i] = convert_and_insert<Int8>(column);
else if (type->equals(*std::make_shared<DataTypeInt16>()))
else if (which.isInt16())
columns[i] = convert_and_insert<Int16>(column);
else if (type->equals(*std::make_shared<DataTypeInt32>()))
else if (which.isInt32())
columns[i] = convert_and_insert<Int32>(column);
else if (type->equals(*std::make_shared<DataTypeInt64>()))
else if (which.isInt64())
columns[i] = convert_and_insert<Int64>(column);
else if (type->equals(*std::make_shared<DataTypeFloat32>()))
else if (which.isInt128())
columns[i] = convert_and_insert<Int128>(column);
else if (which.isInt256())
columns[i] = convert_and_insert<Int256>(column);
else if (which.isFloat32())
columns[i] = convert_and_insert<Float32>(column);
else if (type->equals(*std::make_shared<DataTypeFloat64>()))
else if (which.isFloat64())
columns[i] = convert_and_insert<Float64>(column);
else if (type->equals(*std::make_shared<DataTypeString>()))
else if (which.isDecimal128())
{
const auto & dtype = typeid_cast<const DataTypeDecimal<Decimal128> *>(type.get());
columns[i] = convert_and_insert<Decimal128>(column, dtype->getScale());
}
else if (which.isDecimal256())
{
const auto & dtype = typeid_cast<const DataTypeDecimal<Decimal256> *>(type.get());
columns[i] = convert_and_insert<Decimal256>(column, dtype->getScale());
}
else if (which.isDateTime())
columns[i] = convert_and_insert<UInt32>(column);
else if (which.isDateTime64())
columns[i] = convert_and_insert<DateTime64>(column);
else if (which.isString())
columns[i] = convert_and_insert<String>(column);
else
throw Exception(ErrorCodes::BAD_TYPE_OF_FIELD, "Unsupported type {}", type->getName());
throw Exception(
ErrorCodes::BAD_TYPE_OF_FIELD,
"Unsupported type {} for column {}",
type->getName(),
description.sample_block.getByPosition(i).name);
}
// Set data vector to empty to avoid trigger py::object destructor without GIL
// Note: we have already manually decremented the reference count of the list or array in `convert_and_insert` function
Expand Down
134 changes: 132 additions & 2 deletions src/Storages/StoragePython.cpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionsConversion.h>
#include <Interpreters/evaluateConstantExpression.h>
#include <Processors/Sources/PythonSource.h>
#include <Storages/ColumnsDescription.h>
#include <Storages/IStorage.h>
#include <Storages/StorageFactory.h>
#include <Storages/StoragePython.h>
#include <base/types.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <pybind11/stl.h>
#include <re2/re2.h>
#include <Poco/Logger.h>
#include <Common/Exception.h>
#include <Common/logger_useful.h>

#include <any>

Expand All @@ -22,8 +34,6 @@ extern const int LOGICAL_ERROR;
extern const int BAD_TYPE_OF_FIELD;
}

namespace py = pybind11;


StoragePython::StoragePython(
const StorageID & table_id_,
Expand Down Expand Up @@ -66,6 +76,126 @@ Block StoragePython::prepareSampleBlock(const Names & column_names, const Storag
return sample_block;
}

/// Derives the table structure (column names and ClickHouse data types) from the schema
/// reported by the Python reader.
///
/// `reader->getSchema()` yields (column name, type description) pairs. Each type description
/// string is matched against the regular expressions below, in order, and mapped to a
/// ClickHouse data type.
///
/// Throws LOGICAL_ERROR if `reader` is null.
/// Throws TYPE_MISMATCH if a type description cannot be mapped; this includes a recognised
/// family with an unsupported bit width (e.g. `int4`, `float16`).
ColumnsDescription StoragePython::getTableStructureFromData(std::shared_ptr<PyReader> reader)
{
    if (!reader)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Python reader not initialized");
    auto schema = reader->getSchema();

    auto * logger = &Poco::Logger::get("StoragePython");
    if (logger->debug())
    {
        LOG_DEBUG(logger, "Schema content:");
        for (const auto & item : schema)
            LOG_DEBUG(logger, "Column: {}, Type: {}", String(item.first), String(item.second));
    }

    NamesAndTypesList names_and_types;

    // Regular expressions for the supported type descriptions.
    const RE2 pattern_int(R"(\bint(\d+))");
    const RE2 pattern_generic_int(R"(\bint\b|<class 'int'>)"); // Matches generic 'int'
    const RE2 pattern_uint(R"(\buint(\d+))");
    // The family name must be a NON-capturing group: RE2 stores the i-th capture in the i-th
    // output argument, so a capturing `(float|double)` would put "float" into `bits`.
    const RE2 pattern_float(R"(\b(?:float|double)(\d+))");
    // Bare `double` carries no width suffix but is unambiguously 64-bit.
    // NOTE(review): bare `float` is deliberately not mapped here, its width depends on the
    // producer (32-bit in some schemas, 64-bit for Python's builtin) — confirm before adding.
    const RE2 pattern_double(R"(\bdouble\b)");
    const RE2 pattern_decimal128(R"(decimal128\((\d+),\s*(\d+)\))");
    const RE2 pattern_decimal256(R"(decimal256\((\d+),\s*(\d+)\))");
    const RE2 pattern_date32(R"(\bdate32\b)");
    const RE2 pattern_date64(R"(\bdate64\b)");
    const RE2 pattern_time32(R"(\btime32\b)");
    // No trailing `\b` here: `]` is a non-word character, so `\]\b` would only match when a
    // word character follows and a plain "time64[us]" would never be recognised.
    const RE2 pattern_time64_us(R"(\btime64\[us\])");
    const RE2 pattern_time64_ns(R"(\btime64\[ns\])");
    const RE2 pattern_string_binary(R"(\bstring\b|<class 'str'>|str|DataType\(string\)|DataType\(binary\)|dtype\[object_\]|dtype\('O'\))");

    // Map every (name, type description) pair of the schema to a ClickHouse type.
    for (const auto & [name, typeStr] : schema)
    {
        // Stays null when no pattern matches or when the bit width is not supported.
        std::shared_ptr<IDataType> data_type;

        std::string bits, precision, scale;
        if (RE2::PartialMatch(typeStr, pattern_int, &bits))
        {
            if (bits == "8")
                data_type = std::make_shared<DataTypeInt8>();
            else if (bits == "16")
                data_type = std::make_shared<DataTypeInt16>();
            else if (bits == "32")
                data_type = std::make_shared<DataTypeInt32>();
            else if (bits == "64")
                data_type = std::make_shared<DataTypeInt64>();
            else if (bits == "128")
                data_type = std::make_shared<DataTypeInt128>();
            else if (bits == "256")
                data_type = std::make_shared<DataTypeInt256>();
        }
        else if (RE2::PartialMatch(typeStr, pattern_uint, &bits))
        {
            if (bits == "8")
                data_type = std::make_shared<DataTypeUInt8>();
            else if (bits == "16")
                data_type = std::make_shared<DataTypeUInt16>();
            else if (bits == "32")
                data_type = std::make_shared<DataTypeUInt32>();
            else if (bits == "64")
                data_type = std::make_shared<DataTypeUInt64>();
            else if (bits == "128")
                data_type = std::make_shared<DataTypeUInt128>();
            else if (bits == "256")
                data_type = std::make_shared<DataTypeUInt256>();
        }
        else if (RE2::PartialMatch(typeStr, pattern_generic_int))
        {
            data_type = std::make_shared<DataTypeInt64>(); // Default to 64-bit integers for generic 'int'
        }
        else if (RE2::PartialMatch(typeStr, pattern_float, &bits))
        {
            if (bits == "32")
                data_type = std::make_shared<DataTypeFloat32>();
            else if (bits == "64")
                data_type = std::make_shared<DataTypeFloat64>();
        }
        else if (RE2::PartialMatch(typeStr, pattern_double))
        {
            data_type = std::make_shared<DataTypeFloat64>();
        }
        else if (RE2::PartialMatch(typeStr, pattern_decimal128, &precision, &scale))
        {
            data_type = std::make_shared<DataTypeDecimal128>(std::stoi(precision), std::stoi(scale));
        }
        else if (RE2::PartialMatch(typeStr, pattern_decimal256, &precision, &scale))
        {
            data_type = std::make_shared<DataTypeDecimal256>(std::stoi(precision), std::stoi(scale));
        }
        else if (RE2::PartialMatch(typeStr, pattern_date32))
        {
            data_type = std::make_shared<DataTypeDate32>();
        }
        else if (RE2::PartialMatch(typeStr, pattern_date64))
        {
            data_type = std::make_shared<DataTypeDateTime64>(3); // date64 corresponds to DateTime64(3)
        }
        else if (RE2::PartialMatch(typeStr, pattern_time32))
        {
            data_type = std::make_shared<DataTypeDateTime>();
        }
        else if (RE2::PartialMatch(typeStr, pattern_time64_us))
        {
            data_type = std::make_shared<DataTypeDateTime64>(6); // time64[us] corresponds to DateTime64(6)
        }
        else if (RE2::PartialMatch(typeStr, pattern_time64_ns))
        {
            data_type = std::make_shared<DataTypeDateTime64>(9); // time64[ns] corresponds to DateTime64(9)
        }
        else if (RE2::PartialMatch(typeStr, pattern_string_binary))
        {
            data_type = std::make_shared<DataTypeString>();
        }

        // Covers both "no pattern matched" and "family matched but width unsupported";
        // a null type must never reach the resulting ColumnsDescription.
        if (!data_type)
            throw Exception(ErrorCodes::TYPE_MISMATCH, "Unrecognized data type: {}", typeStr);

        names_and_types.push_back({name, data_type});
    }

    return ColumnsDescription(names_and_types);
}

void registerStoragePython(StorageFactory & factory)
{
factory.registerStorage(
Expand Down
Loading

0 comments on commit 68e47fd

Please sign in to comment.