Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support skip/limit options for pandas scan #4662

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pybind11_add_module(_kuzu
src_cpp/py_prepared_statement.cpp
src_cpp/py_query_result.cpp
src_cpp/py_query_result_converter.cpp
src_cpp/py_scan_config.cpp
src_cpp/py_udf.cpp
src_cpp/py_conversion.cpp
src_cpp/pyarrow/pyarrow_bind.cpp
Expand Down
21 changes: 9 additions & 12 deletions tools/python_api/src_cpp/include/pandas/pandas_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "function/table/scan_functions.h"
#include "function/table_functions.h"
#include "pandas_bind.h"
#include "py_scan_config.h"

namespace kuzu {

Expand All @@ -15,9 +16,10 @@ struct PandasScanLocalState final : public function::TableFuncLocalState {
};

struct PandasScanSharedState final : public function::BaseScanSharedStateWithNumRows {
explicit PandasScanSharedState(uint64_t numRows)
: BaseScanSharedStateWithNumRows{numRows}, numRowsRead{0} {}
PandasScanSharedState(uint64_t startRow, uint64_t numRows)
: BaseScanSharedStateWithNumRows{numRows}, startRow(startRow), numRowsRead{0} {}

uint64_t startRow;
uint64_t numRowsRead;
};

Expand All @@ -31,23 +33,19 @@ struct PandasScanFunction {
struct PandasScanFunctionData : public function::TableFuncBindData {
py::handle df;
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData;
common::FileScanInfo fileScanInfo;
PyScanConfig scanConfig;

PandasScanFunctionData(binder::expression_vector columns, py::handle df, uint64_t numRows,
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData,
common::FileScanInfo fileScanInfo)
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData, PyScanConfig scanConfig)
: TableFuncBindData{std::move(columns), 0 /* numWarningDataColumns */, numRows}, df{df},
columnBindData{std::move(columnBindData)}, fileScanInfo(std::move(fileScanInfo)) {}
columnBindData{std::move(columnBindData)}, scanConfig(scanConfig) {}

~PandasScanFunctionData() override {
py::gil_scoped_acquire acquire;
columnBindData.clear();
}

bool getIgnoreErrorsOption() const override {
return fileScanInfo.getOption(common::CopyConstants::IGNORE_ERRORS_OPTION_NAME,
common::CopyConstants::DEFAULT_IGNORE_ERRORS);
}
bool getIgnoreErrorsOption() const override { return scanConfig.ignoreErrors; }

std::vector<std::unique_ptr<PandasColumnBindData>> copyColumnBindData() const;

Expand All @@ -57,11 +55,10 @@ struct PandasScanFunctionData : public function::TableFuncBindData {

private:
PandasScanFunctionData(const PandasScanFunctionData& other)
: TableFuncBindData{other}, df{other.df} {
: TableFuncBindData{other}, df{other.df}, scanConfig(other.scanConfig) {
for (const auto& i : other.columnBindData) {
columnBindData.push_back(i->copy());
}
fileScanInfo = other.fileScanInfo.copy();
}
};

Expand Down
16 changes: 16 additions & 0 deletions tools/python_api/src_cpp/include/py_scan_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "common/case_insensitive_map.h"
#include "common/types/value/value.h"

namespace kuzu {

/// Scan options used by the Python-object (pandas / pyarrow) table scans.
struct PyScanConfig {
    /// Builds the configuration from the user-supplied scan `options`.
    /// `numRows` is the number of rows in the object being scanned.
    explicit PyScanConfig(const common::case_insensitive_map_t<common::Value>& options,
        uint64_t numRows);

    /// Number of leading rows to skip (SKIP option).
    uint64_t skipNum;
    /// Maximum number of rows to produce (LIMIT option).
    uint64_t limitNum;
    /// Value of the IGNORE_ERRORS option.
    bool ignoreErrors;
};

} // namespace kuzu
7 changes: 0 additions & 7 deletions tools/python_api/src_cpp/include/pyarrow/pyarrow_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,6 @@

namespace kuzu {

struct PyArrowScanConfig {
uint64_t skipNum;
uint64_t limitNum;
bool ignoreErrors;
explicit PyArrowScanConfig(const common::case_insensitive_map_t<common::Value>& options);
};

struct PyArrowTableScanLocalState final : public function::TableFuncLocalState {
ArrowArrayWrapper* arrowArray;

Expand Down
19 changes: 12 additions & 7 deletions tools/python_api/src_cpp/pandas/pandas_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "numpy/numpy_scan.h"
#include "processor/execution_context.h"
#include "py_connection.h"
#include "py_scan_config.h"
#include "pyarrow/pyarrow_scan.h"
#include "pybind11/pytypes.h"

Expand All @@ -32,10 +33,13 @@ std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* /*context*/,
auto getFunc = df.attr("__getitem__");
auto numRows = py::len(getFunc(columns[0]));
auto returnColumns = input->binder->createVariables(names, returnTypes);
auto scanConfig =
input->extraInput->constPtrCast<ExtraScanTableFuncBindInput>()->fileScanInfo.copy();
return std::make_unique<PandasScanFunctionData>(std::move(returnColumns), df, numRows,
std::move(columnBindData), std::move(scanConfig));
auto scanConfig = PyScanConfig{
input->extraInput->constPtrCast<ExtraScanTableFuncBindInput>()->fileScanInfo.options,
numRows};
KU_ASSERT(numRows >= scanConfig.skipNum);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, do we have a check somewhere to ensure scanConfig.skipNum is always smaller than numRows? I'm concerned that this will lead to an overflow in the case where numRows < scanConfig.skipNum.

I would also add a test case on that if we don't have one already.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a check in the scan config constructor that bounds skipNum to the number of rows, and added tests.

return std::make_unique<PandasScanFunctionData>(std::move(returnColumns), df,
std::min(numRows - scanConfig.skipNum, scanConfig.limitNum), std::move(columnBindData),
scanConfig);
}

bool sharedStateNext(const TableFuncBindData* /*bindData*/, PandasScanLocalState* localState,
Expand All @@ -45,11 +49,11 @@ bool sharedStateNext(const TableFuncBindData* /*bindData*/, PandasScanLocalState
if (pandasSharedState->numRowsRead >= pandasSharedState->numRows) {
return false;
}
localState->start = pandasSharedState->numRowsRead;
localState->start = pandasSharedState->startRow + pandasSharedState->numRowsRead;
pandasSharedState->numRowsRead +=
std::min(pandasSharedState->numRows - pandasSharedState->numRowsRead,
CopyConstants::PANDAS_PARTITION_COUNT);
localState->end = pandasSharedState->numRowsRead;
localState->end = pandasSharedState->startRow + pandasSharedState->numRowsRead;
return true;
}

Expand All @@ -67,7 +71,8 @@ std::unique_ptr<TableFuncSharedState> initSharedState(const TableFunctionInitInp
}
// LCOV_EXCL_STOP
auto scanBindData = ku_dynamic_cast<PandasScanFunctionData*>(input.bindData);
return std::make_unique<PandasScanSharedState>(scanBindData->cardinality);
return std::make_unique<PandasScanSharedState>(scanBindData->scanConfig.skipNum,
scanBindData->cardinality);
}

void pandasBackendScanSwitch(PandasColumnBindData* bindData, uint64_t count, uint64_t offset,
Expand Down
39 changes: 39 additions & 0 deletions tools/python_api/src_cpp/py_scan_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "py_scan_config.h"

#include "common/constants.h"
#include "common/exception/binder.h"
#include "function/cast/functions/numeric_limits.h"

namespace kuzu {

/**
 * Parses the scan options into a PyScanConfig.
 *
 * Recognized options:
 *  - SKIP: non-negative INT64; stored in skipNum, clamped to `numRows`.
 *  - LIMIT: non-negative INT64; stored in limitNum as given (not clamped).
 *  - IGNORE_ERRORS: BOOL; stored in ignoreErrors.
 *
 * Options that are absent keep their defaults (skip 0, limit = max uint64,
 * ignoreErrors = CopyConstants::DEFAULT_IGNORE_ERRORS).
 *
 * @param options user-supplied scan options keyed by option name.
 * @param numRows number of rows in the scanned object; upper bound for skipNum.
 * @throws common::BinderException if an option has the wrong type, a negative
 *         SKIP/LIMIT value, or an unrecognized name.
 */
PyScanConfig::PyScanConfig(const common::case_insensitive_map_t<common::Value>& options,
    uint64_t numRows)
    : skipNum(0), limitNum(function::NumericLimits<uint64_t>::maximum()),
      ignoreErrors(common::CopyConstants::DEFAULT_IGNORE_ERRORS) {
    // SKIP and LIMIT share the same validation: the value must be an INT64 that is >= 0.
    const auto readNonNegativeInt64 = [](const common::Value& value,
                                          const char* errorMessage) -> uint64_t {
        const bool isInt64 =
            value.getDataType().getLogicalTypeID() == common::LogicalTypeID::INT64;
        if (!isInt64 || value.val.int64Val < 0) {
            throw common::BinderException(errorMessage);
        }
        return static_cast<uint64_t>(value.val.int64Val);
    };

    for (const auto& [optionName, optionValue] : options) {
        if (optionName == "SKIP") {
            const uint64_t requestedSkip = readNonNegativeInt64(optionValue,
                "SKIP Option must be a positive integer literal.");
            // Never skip past the end of the input.
            skipNum = std::min(numRows, requestedSkip);
            continue;
        }
        if (optionName == "LIMIT") {
            limitNum = readNonNegativeInt64(optionValue,
                "LIMIT Option must be a positive integer literal.");
            continue;
        }
        if (optionName == common::CopyConstants::IGNORE_ERRORS_OPTION_NAME) {
            if (optionValue.getDataType().getLogicalTypeID() != common::LogicalTypeID::BOOL) {
                throw common::BinderException("IGNORE_ERRORS Option must be a boolean.");
            }
            ignoreErrors = optionValue.val.booleanVal;
            continue;
        }
        throw common::BinderException(
            common::stringFormat("{} Option not recognized by pyArrow scanner.", optionName));
    }
}

} // namespace kuzu
32 changes: 2 additions & 30 deletions tools/python_api/src_cpp/pyarrow/pyarrow_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "function/table/bind_input.h"
#include "processor/execution_context.h"
#include "py_connection.h"
#include "py_scan_config.h"
#include "pyarrow/pyarrow_bind.h"
#include "pybind11/pytypes.h"

Expand All @@ -16,35 +17,6 @@ using namespace kuzu::catalog;

namespace kuzu {

PyArrowScanConfig::PyArrowScanConfig(const case_insensitive_map_t<Value>& options) {
skipNum = 0;
limitNum = NumericLimits<uint64_t>::maximum();
ignoreErrors = CopyConstants::DEFAULT_IGNORE_ERRORS;
for (const auto& i : options) {
if (i.first == "SKIP") {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
i.second.val.int64Val < 0) {
throw BinderException("SKIP Option must be a positive integer literal.");
}
skipNum = i.second.val.int64Val;
} else if (i.first == "LIMIT") {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
i.second.val.int64Val < 0) {
throw BinderException("LIMIT Option must be a positive integer literal.");
}
limitNum = i.second.val.int64Val;
} else if (i.first == CopyConstants::IGNORE_ERRORS_OPTION_NAME) {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::BOOL) {
throw BinderException("IGNORE_ERRORS Option must be a boolean.");
}
ignoreErrors = i.second.val.booleanVal;
} else {
throw BinderException(
stringFormat("{} Option not recognized by pyArrow scanner.", i.first));
}
}
}

template<typename T>
static bool moduleIsLoaded() {
auto dict = pybind11::module_::import("sys").attr("modules");
Expand Down Expand Up @@ -73,7 +45,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext*,
}
auto numRows = py::len(table);
auto schema = Pyarrow::bind(table, returnTypes, names);
auto config = PyArrowScanConfig(scanInput->fileScanInfo.options);
auto config = PyScanConfig(scanInput->fileScanInfo.options, numRows);
// The following python operations are zero copy as defined in pyarrow docs.
if (config.skipNum != 0) {
table = table.attr("slice")(config.skipNum);
Expand Down
89 changes: 86 additions & 3 deletions tools/python_api/test/test_scan_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def test_scan_pandas(tmp_path: Path) -> None:
"INT32": np.array([-100, -200, -300, -400], dtype=np.int32),
"INT64": np.array([-1000, -2000, -3000, -4000], dtype=np.int64),
"FLOAT_32": np.array(
[-0.5199999809265137, float("nan"), -3.299999952316284, 4.400000095367432], dtype=np.float32
[-0.5199999809265137, float("nan"), -3.299999952316284, 4.400000095367432],
dtype=np.float32,
),
"FLOAT_64": np.array([5132.12321, 24.222, float("nan"), 4.444], dtype=np.float64),
"datetime_microseconds": np.array([
Expand Down Expand Up @@ -312,8 +313,18 @@ def test_pandas_scan_demo(tmp_path: Path) -> None:
"height_in_inch RETURN s"
).get_as_df()
assert len(result) == 2
assert result["s"][0] == {"ID": 0, "_id": {"offset": 0, "table": 0}, "_label": "student", "height": 70}
assert result["s"][1] == {"ID": 4, "_id": {"offset": 2, "table": 0}, "_label": "student", "height": 67}
assert result["s"][0] == {
"ID": 0,
"_id": {"offset": 0, "table": 0},
"_label": "student",
"height": 70,
}
assert result["s"][1] == {
"ID": 4,
"_id": {"offset": 2, "table": 0},
"_label": "student",
"height": 67,
}

conn.execute("CREATE NODE TABLE person(ID INT64, age UINT16, height UINT32, is_student BOOLean, PRIMARY KEY(ID))")
conn.execute("LOAD FROM person CREATE (p:person {ID: id, age: age, height: height, is_student: is_student})")
Expand Down Expand Up @@ -402,6 +413,78 @@ def test_copy_from_pandas_object(tmp_path: Path) -> None:
assert result.has_next() is False


def test_copy_from_pandas_object_skip(tmp_path: Path) -> None:
    """COPY FROM a pandas DataFrame with SKIP omits the leading rows (node and rel tables)."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The query text refers to the DataFrame as `df`, so this local name must be kept.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    connection.execute("COPY Person FROM df(SKIP=2);")
    people = connection.execute("match (p:Person) return p.*")
    # Only the last two of the four rows are expected.
    for expected_row in (["Zhang", "50"], ["Noura", "25"]):
        assert people.get_next() == expected_row
    assert people.has_next() is False

    # Relationship copy: skipping the first row leaves only Noura -> Zhang.
    df = pd.DataFrame({"f": ["Adam", "Noura"], "t": ["Zhang", "Zhang"]})
    connection.execute("CREATE REL TABLE Knows(FROM Person TO Person);")
    connection.execute("COPY Knows FROM df(SKIP=1)")
    sources = connection.execute("match (p:Person)-[]->(:Person {name: 'Zhang'}) return p.*")
    assert sources.get_next() == ["Noura", "25"]
    assert sources.has_next() is False


def test_copy_from_pandas_object_limit(tmp_path: Path) -> None:
    """COPY FROM a pandas DataFrame with LIMIT loads only the leading rows (node and rel tables)."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The query text refers to the DataFrame as `df`, so this local name must be kept.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    connection.execute("COPY Person FROM df(LIMIT=2);")
    people = connection.execute("match (p:Person) return p.*")
    # Only the first two of the four rows are expected.
    for expected_row in (["Adam", "30"], ["Karissa", "40"]):
        assert people.get_next() == expected_row
    assert people.has_next() is False

    # Relationship copy: limiting to one row leaves only Adam -> Karissa.
    df = pd.DataFrame({"f": ["Adam", "Zhang"], "t": ["Karissa", "Karissa"]})
    connection.execute("CREATE REL TABLE Knows(FROM Person TO Person);")
    connection.execute("COPY Knows FROM df(LIMIT=1)")
    sources = connection.execute("match (p:Person)-[]->(:Person {name: 'Karissa'}) return p.*")
    assert sources.get_next() == ["Adam", "30"]
    assert sources.has_next() is False


def test_copy_from_pandas_object_skip_and_limit(tmp_path: Path) -> None:
    """COPY FROM a pandas DataFrame with SKIP=1 and LIMIT=2 loads the second and third rows."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The query text refers to the DataFrame as `df`, so this local name must be kept.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    connection.execute("COPY Person FROM df(SKIP=1, LIMIT=2);")

    people = connection.execute("match (p:Person) return p.*")
    for expected_row in (["Karissa", "40"], ["Zhang", "50"]):
        assert people.get_next() == expected_row
    assert people.has_next() is False


def test_copy_from_pandas_object_skip_bounds_check(tmp_path: Path) -> None:
    """A SKIP value larger than the DataFrame's row count copies no rows."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The query text refers to the DataFrame as `df`, so this local name must be kept.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    # SKIP=10 exceeds the four available rows.
    connection.execute("COPY Person FROM df(SKIP=10);")

    people = connection.execute("match (p:Person) return p.*")
    assert people.has_next() is False


def test_copy_from_pandas_object_limit_bounds_check(tmp_path: Path) -> None:
    """A LIMIT value larger than the DataFrame's row count copies every row."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The query text refers to the DataFrame as `df`, so this local name must be kept.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    # LIMIT=10 exceeds the four available rows.
    connection.execute("COPY Person FROM df(LIMIT=10);")

    people = connection.execute("match (p:Person) return p.*")
    all_rows = (
        ["Adam", "30"],
        ["Karissa", "40"],
        ["Zhang", "50"],
        ["Noura", "25"],
    )
    for expected_row in all_rows:
        assert people.get_next() == expected_row
    assert people.has_next() is False


def test_copy_from_pandas_date(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
Expand Down
12 changes: 12 additions & 0 deletions tools/python_api/test/test_scan_pandas_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,18 @@ def test_pyarrow_skip_limit(conn_db_readonly: ConnDB) -> None:
assert result["col1"].to_pylist() == expected["col1"].to_pylist()
assert result["col2"].to_pylist() == expected["col2"].to_pylist()

# skip bounds check
result = conn.execute("LOAD FROM df (SKIP=500000, LIMIT=5000) RETURN * ORDER BY index").get_as_arrow()
assert len(result) == 0

# limit bounds check
result = conn.execute("LOAD FROM df (SKIP=0, LIMIT=500000) RETURN * ORDER BY index").get_as_arrow()
expected = pa.Table.from_pandas(df)
assert result["index"].to_pylist() == expected["index"].to_pylist()
assert result["col0"].to_pylist() == expected["col0"].to_pylist()
assert result["col1"].to_pylist() == expected["col1"].to_pylist()
assert result["col2"].to_pylist() == expected["col2"].to_pylist()


def test_pyarrow_invalid_skip_limit(conn_db_readonly: ConnDB) -> None:
conn, db = conn_db_readonly
Expand Down
Loading