Skip to content

Commit

Permalink
Support IGNORE_ERRORS when scanning from pyarrow/pandas (#4646)
Browse files Browse the repository at this point in the history
* Support IGNORE_ERRORS when scanning from pyarrow/pandas

* minor fix

---------

Co-authored-by: xiyang <[email protected]>
  • Loading branch information
royi-luo and andyfengHKU authored Jan 2, 2025
1 parent f19d605 commit 42cd8a0
Show file tree
Hide file tree
Showing 19 changed files with 161 additions and 37 deletions.
9 changes: 2 additions & 7 deletions src/binder/bound_scan_source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@ expression_vector BoundTableScanSource::getWarningColumns() const {
}
return warningDataExprs;
}

bool BoundTableScanSource::getIgnoreErrorsOption() const {
bool ignoreErrors = common::CopyConstants::DEFAULT_IGNORE_ERRORS;
if (type == common::ScanSourceType::FILE) {
auto bindData = info.bindData->constPtrCast<function::ScanBindData>();
ignoreErrors = bindData->config.getOption(common::CopyConstants::IGNORE_ERRORS_OPTION_NAME,
ignoreErrors);
}
return ignoreErrors;
return info.bindData->getIgnoreErrorsOption();
}

} // namespace binder
Expand Down
11 changes: 11 additions & 0 deletions src/function/table/bind_data.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "function/table/bind_data.h"

#include "common/constants.h"

namespace kuzu {
namespace function {

Expand All @@ -14,5 +16,14 @@ std::vector<bool> TableFuncBindData::getColumnSkips() const {
return columnSkips;
}

bool TableFuncBindData::getIgnoreErrorsOption() const {
return common::CopyConstants::DEFAULT_IGNORE_ERRORS;
}

bool ScanBindData::getIgnoreErrorsOption() const {
return config.getOption(common::CopyConstants::IGNORE_ERRORS_OPTION_NAME,
common::CopyConstants::DEFAULT_IGNORE_ERRORS);
}

} // namespace function
} // namespace kuzu
4 changes: 2 additions & 2 deletions src/include/common/vector/auxiliary_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace common {
class ValueVector;

// AuxiliaryBuffer holds data which is only used by the targeting dataType.
class AuxiliaryBuffer {
class KUZU_API AuxiliaryBuffer {
public:
virtual ~AuxiliaryBuffer() = default;

Expand Down Expand Up @@ -43,7 +43,7 @@ class StringAuxiliaryBuffer : public AuxiliaryBuffer {
std::unique_ptr<InMemOverflowBuffer> inMemOverflowBuffer;
};

class StructAuxiliaryBuffer : public AuxiliaryBuffer {
class KUZU_API StructAuxiliaryBuffer : public AuxiliaryBuffer {
public:
StructAuxiliaryBuffer(const LogicalType& type, storage::MemoryManager* memoryManager);

Expand Down
2 changes: 1 addition & 1 deletion src/include/function/export/export_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ using export_sink_t = std::function<void(ExportFuncSharedState&, ExportFuncLocal
using export_combine_t = std::function<void(ExportFuncSharedState&, ExportFuncLocalState&)>;
using export_finalize_t = std::function<void(ExportFuncSharedState&)>;

struct ExportFunction : public Function {
struct KUZU_API ExportFunction : public Function {
explicit ExportFunction(std::string name) : Function{std::move(name), {}} {}

ExportFunction(std::string name, export_init_local_t initLocal,
Expand Down
2 changes: 1 addition & 1 deletion src/include/function/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct ScalarBindFuncInput {
using scalar_bind_func =
std::function<std::unique_ptr<FunctionBindData>(const ScalarBindFuncInput& bindInput)>;

struct Function {
struct KUZU_API Function {
std::string name;
std::vector<common::LogicalTypeID> parameterTypeIDs;
// Currently we only one variable-length function which is list creation. The expectation is
Expand Down
2 changes: 1 addition & 1 deletion src/include/function/gds_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
namespace kuzu {
namespace function {

struct GDSFunction : public Function {
struct KUZU_API GDSFunction : public Function {
std::unique_ptr<GDSAlgorithm> gds;

GDSFunction() = default;
Expand Down
2 changes: 1 addition & 1 deletion src/include/function/scalar_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using scalar_func_exec_t = std::function<void(
using scalar_func_select_t = std::function<bool(
const std::vector<std::shared_ptr<common::ValueVector>>&, common::SelectionVector&)>;

struct ScalarFunction : ScalarOrAggregateFunction {
struct KUZU_API ScalarFunction : public ScalarOrAggregateFunction {
scalar_func_exec_t execFunc = nullptr;
scalar_func_select_t selectFunc = nullptr;
scalar_func_compile_exec_t compileFunc = nullptr;
Expand Down
11 changes: 8 additions & 3 deletions src/include/function/table/bind_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class FileSystem;

namespace function {

struct TableFuncBindData {
struct KUZU_API TableFuncBindData {
binder::expression_vector columns;
// the last {numWarningDataColumns} columns are for temporary internal use
common::column_id_t numWarningDataColumns;
Expand All @@ -29,11 +29,12 @@ struct TableFuncBindData {
: columns{other.columns}, numWarningDataColumns(other.numWarningDataColumns),
cardinality{other.cardinality}, columnSkips{other.columnSkips},
columnPredicates{copyVector(other.columnPredicates)} {}
TableFuncBindData& operator=(const TableFuncBindData& other) = delete;
virtual ~TableFuncBindData() = default;

common::idx_t getNumColumns() const { return columns.size(); }
void setColumnSkips(std::vector<bool> skips) { columnSkips = std::move(skips); }
KUZU_API std::vector<bool> getColumnSkips() const;
std::vector<bool> getColumnSkips() const;

void setColumnPredicates(std::vector<storage::ColumnPredicateSet> predicates) {
columnPredicates = std::move(predicates);
Expand All @@ -42,6 +43,8 @@ struct TableFuncBindData {
return columnPredicates;
}

virtual bool getIgnoreErrorsOption() const;

virtual std::unique_ptr<TableFuncBindData> copy() const = 0;

template<class TARGET>
Expand All @@ -59,7 +62,7 @@ struct TableFuncBindData {
std::vector<storage::ColumnPredicateSet> columnPredicates;
};

struct ScanBindData : public TableFuncBindData {
struct KUZU_API ScanBindData : public TableFuncBindData {
common::ReaderConfig config;
main::ClientContext* context;

Expand All @@ -74,6 +77,8 @@ struct ScanBindData : public TableFuncBindData {
ScanBindData(const ScanBindData& other)
: TableFuncBindData{other}, config{other.config.copy()}, context{other.context} {}

bool getIgnoreErrorsOption() const override;

std::unique_ptr<TableFuncBindData> copy() const override {
return std::make_unique<ScanBindData>(*this);
}
Expand Down
4 changes: 2 additions & 2 deletions src/include/function/table/bind_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct ExtraTableFuncBindInput {
}
};

struct TableFuncBindInput {
struct KUZU_API TableFuncBindInput {
binder::expression_vector params;
optional_params_t optionalParams;
std::unique_ptr<ExtraTableFuncBindInput> extraInput = nullptr;
Expand All @@ -51,7 +51,7 @@ struct TableFuncBindInput {
T getLiteralVal(common::idx_t idx) const;
};

struct ExtraScanTableFuncBindInput : ExtraTableFuncBindInput {
struct KUZU_API ExtraScanTableFuncBindInput : ExtraTableFuncBindInput {
common::ReaderConfig config;
std::vector<std::string> expectedColumnNames;
std::vector<common::LogicalType> expectedColumnTypes;
Expand Down
2 changes: 1 addition & 1 deletion src/include/function/table/scan_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class FileSystem;

namespace function {

struct BaseScanSharedState : public TableFuncSharedState {
struct KUZU_API BaseScanSharedState : public TableFuncSharedState {
std::mutex lock;

virtual uint64_t getNumRows() const = 0;
Expand Down
6 changes: 3 additions & 3 deletions src/include/function/table/simple_table_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ struct SimpleTableFuncMorsel {
}
};

struct SimpleTableFuncSharedState final : TableFuncSharedState {
struct KUZU_API SimpleTableFuncSharedState final : TableFuncSharedState {
common::offset_t maxOffset;
common::offset_t curOffset;
std::mutex mtx;

explicit SimpleTableFuncSharedState(common::offset_t maxOffset)
: maxOffset{maxOffset}, curOffset{0} {}

KUZU_API SimpleTableFuncMorsel getMorsel();
SimpleTableFuncMorsel getMorsel();
};

struct SimpleTableFuncBindData : TableFuncBindData {
struct KUZU_API SimpleTableFuncBindData : TableFuncBindData {
common::offset_t maxOffset;

explicit SimpleTableFuncBindData(common::offset_t maxOffset)
Expand Down
2 changes: 1 addition & 1 deletion src/include/function/table_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace function {
struct TableFuncBindInput;
struct TableFuncBindData;

struct TableFuncSharedState {
struct KUZU_API TableFuncSharedState {
virtual ~TableFuncSharedState() = default;

template<class TARGET>
Expand Down
14 changes: 10 additions & 4 deletions tools/python_api/src_cpp/include/pandas/pandas_scan.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#pragma once

#include "function/scalar_function.h"
#include "function/table/bind_data.h"
#include "function/table/scan_functions.h"
#include "function/table_functions.h"
#include "pandas_bind.h"
#include "pybind_include.h"

namespace kuzu {

Expand Down Expand Up @@ -33,17 +31,24 @@ struct PandasScanFunction {
struct PandasScanFunctionData : public function::TableFuncBindData {
py::handle df;
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData;
common::ReaderConfig config;

PandasScanFunctionData(binder::expression_vector columns, py::handle df, uint64_t numRows,
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData)
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData,
common::ReaderConfig config)
: TableFuncBindData{std::move(columns), 0 /* numWarningDataColumns */, numRows}, df{df},
columnBindData{std::move(columnBindData)} {}
columnBindData{std::move(columnBindData)}, config(std::move(config)) {}

~PandasScanFunctionData() override {
py::gil_scoped_acquire acquire;
columnBindData.clear();
}

bool getIgnoreErrorsOption() const override {
return config.getOption(common::CopyConstants::IGNORE_ERRORS_OPTION_NAME,
common::CopyConstants::DEFAULT_IGNORE_ERRORS);
}

std::vector<std::unique_ptr<PandasColumnBindData>> copyColumnBindData() const;

std::unique_ptr<function::TableFuncBindData> copy() const override {
Expand All @@ -56,6 +61,7 @@ struct PandasScanFunctionData : public function::TableFuncBindData {
for (const auto& i : other.columnBindData) {
columnBindData.push_back(i->copy());
}
config = other.config.copy();
}
};

Expand Down
16 changes: 11 additions & 5 deletions tools/python_api/src_cpp/include/pyarrow/pyarrow_scan.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <utility>

#include "common/arrow/arrow.h"
#include "function/scalar_function.h"
#include "function/table/bind_data.h"
Expand All @@ -12,6 +14,7 @@ namespace kuzu {
struct PyArrowScanConfig {
uint64_t skipNum;
uint64_t limitNum;
bool ignoreErrors;
explicit PyArrowScanConfig(const common::case_insensitive_map_t<common::Value>& options);
};

Expand All @@ -35,19 +38,22 @@ struct PyArrowTableScanSharedState final : public function::BaseScanSharedStateW
struct PyArrowTableScanFunctionData final : public function::TableFuncBindData {
std::shared_ptr<ArrowSchemaWrapper> schema;
std::vector<std::shared_ptr<ArrowArrayWrapper>> arrowArrayBatches;
bool ignoreErrors;

PyArrowTableScanFunctionData(binder::expression_vector columns,
std::shared_ptr<ArrowSchemaWrapper> schema,
std::vector<std::shared_ptr<ArrowArrayWrapper>> arrowArrayBatches, uint64_t numRows)
std::vector<std::shared_ptr<ArrowArrayWrapper>> arrowArrayBatches, uint64_t numRows,
bool ignoreErrors)
: TableFuncBindData{std::move(columns), 0 /* numWarningDataColumns */, numRows},
schema{std::move(schema)}, arrowArrayBatches{std::move(arrowArrayBatches)} {}
schema{std::move(schema)}, arrowArrayBatches{std::move(arrowArrayBatches)},
ignoreErrors(ignoreErrors) {}

~PyArrowTableScanFunctionData() override {}

bool getIgnoreErrorsOption() const override { return ignoreErrors; }

private:
PyArrowTableScanFunctionData(const PyArrowTableScanFunctionData& other)
: TableFuncBindData{other}, schema{other.schema},
arrowArrayBatches{other.arrowArrayBatches} {}
PyArrowTableScanFunctionData(const PyArrowTableScanFunctionData& other) = default;

public:
std::unique_ptr<function::TableFuncBindData> copy() const override {
Expand Down
9 changes: 8 additions & 1 deletion tools/python_api/src_cpp/pandas/pandas_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/exception/runtime.h"
#include "function/table/bind_input.h"
#include "numpy/numpy_scan.h"
#include "processor/execution_context.h"
#include "py_connection.h"
#include "pyarrow/pyarrow_scan.h"
#include "pybind11/pytypes.h"
Expand All @@ -31,8 +32,9 @@ std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* /*context*/,
auto getFunc = df.attr("__getitem__");
auto numRows = py::len(getFunc(columns[0]));
auto returnColumns = input->binder->createVariables(names, returnTypes);
auto scanConfig = input->extraInput->constPtrCast<ExtraScanTableFuncBindInput>()->config.copy();
return std::make_unique<PandasScanFunctionData>(std::move(returnColumns), df, numRows,
std::move(columnBindData));
std::move(columnBindData), std::move(scanConfig));
}

bool sharedStateNext(const TableFuncBindData* /*bindData*/, PandasScanLocalState* localState,
Expand Down Expand Up @@ -119,6 +121,10 @@ static double progressFunc(TableFuncSharedState* sharedState) {
return static_cast<double>(pandasSharedState->numRowsRead) / pandasSharedState->numRows;
}

static void finalizeFunc(const processor::ExecutionContext* ctx, TableFuncSharedState*) {
ctx->clientContext->getWarningContextUnsafe().defaultPopulateAllWarnings(ctx->queryID);
}

function_set PandasScanFunction::getFunctionSet() {
function_set functionSet;
functionSet.push_back(getFunction().copy());
Expand All @@ -132,6 +138,7 @@ TableFunction PandasScanFunction::getFunction() {
function.initSharedStateFunc = initSharedState;
function.initLocalStateFunc = initLocalState;
function.progressFunc = progressFunc;
function.finalizeFunc = finalizeFunc;
return function;
}

Expand Down
17 changes: 15 additions & 2 deletions tools/python_api/src_cpp/pyarrow/pyarrow_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "common/arrow/arrow_converter.h"
#include "function/cast/functions/numeric_limits.h"
#include "function/table/bind_input.h"
#include "processor/execution_context.h"
#include "py_connection.h"
#include "pyarrow/pyarrow_bind.h"
#include "pybind11/pytypes.h"
Expand All @@ -18,6 +19,7 @@ namespace kuzu {
PyArrowScanConfig::PyArrowScanConfig(const case_insensitive_map_t<Value>& options) {
skipNum = 0;
limitNum = NumericLimits<uint64_t>::maximum();
ignoreErrors = CopyConstants::DEFAULT_IGNORE_ERRORS;
for (const auto& i : options) {
if (i.first == "SKIP") {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
Expand All @@ -31,8 +33,14 @@ PyArrowScanConfig::PyArrowScanConfig(const case_insensitive_map_t<Value>& option
throw BinderException("LIMIT Option must be a positive integer literal.");
}
limitNum = i.second.val.int64Val;
} else if (i.first == CopyConstants::IGNORE_ERRORS_OPTION_NAME) {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::BOOL) {
throw BinderException("IGNORE_ERRORS Option must be a boolean.");
}
ignoreErrors = i.second.val.booleanVal;
} else {
throw BinderException(stringFormat("{} Option not recognized by pyArrow scanner."));
throw BinderException(
stringFormat("{} Option not recognized by pyArrow scanner.", i.first));
}
}
}
Expand Down Expand Up @@ -82,7 +90,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext*,

auto columns = input->binder->createVariables(names, returnTypes);
return std::make_unique<PyArrowTableScanFunctionData>(std::move(columns), std::move(schema),
arrowArrayBatches, numRows);
arrowArrayBatches, numRows, config.ignoreErrors);
}

ArrowArrayWrapper* PyArrowTableScanSharedState::getNextChunk() {
Expand Down Expand Up @@ -143,13 +151,18 @@ function_set PyArrowTableScanFunction::getFunctionSet() {
return functionSet;
}

static void finalizeFunc(const processor::ExecutionContext* ctx, TableFuncSharedState*) {
ctx->clientContext->getWarningContextUnsafe().defaultPopulateAllWarnings(ctx->queryID);
}

TableFunction PyArrowTableScanFunction::getFunction() {
auto function = TableFunction(name, std::vector{LogicalTypeID::POINTER});
function.tableFunc = tableFunc;
function.bindFunc = bindFunc;
function.initSharedStateFunc = initSharedState;
function.initLocalStateFunc = initLocalState;
function.progressFunc = progressFunc;
function.finalizeFunc = finalizeFunc;
return function;
}

Expand Down
Loading

0 comments on commit 42cd8a0

Please sign in to comment.