[python-package] Allow to pass Arrow array as groups #6166

Merged · 72 commits · Nov 22, 2023

Commits
ab2d5e2
Add Arrow support to Python API
borchero Jul 31, 2023
570ca64
Merge branch 'master' into arrow-support
borchero Aug 5, 2023
c21fab4
Fix lint
borchero Aug 5, 2023
2cd4302
Fix isort
borchero Aug 5, 2023
71957f6
[python-package] Allow to pass Arrow table as training data
borchero Aug 12, 2023
175fb13
Merge branch 'master' into arrow-support-training-data
borchero Aug 12, 2023
32dfb11
Remove change
borchero Aug 12, 2023
b5f0676
Implement JL comments
borchero Aug 12, 2023
cca3b37
Fix isort
borchero Aug 12, 2023
001139a
Remove testcase
borchero Aug 12, 2023
5861ca6
Adjust pyarrow version
borchero Aug 12, 2023
54d171c
Revert gitignore
borchero Aug 21, 2023
a87a15b
Fix lint
borchero Sep 5, 2023
8cda7cd
Merge branch 'master' into arrow-support-training-data
borchero Sep 5, 2023
6b4245a
Increase timeout for bdist_wheel build
borchero Sep 6, 2023
14a9326
Fix layout
borchero Sep 6, 2023
854f306
Add newline
borchero Sep 6, 2023
269582c
Fix typo
borchero Sep 11, 2023
9164040
Merge branch 'master' into arrow-support-training-data
borchero Sep 11, 2023
9a0a18d
Merge branch 'master' into arrow-support-training-data
borchero Sep 15, 2023
e5540cd
Remove arrow.py
borchero Sep 15, 2023
98997bf
Merge branch 'master' into arrow-support-training-data
jameslamb Sep 26, 2023
4a66cba
Merge branch 'master' into arrow-support-training-data
borchero Oct 12, 2023
f44421e
Fix cpp tests
borchero Oct 12, 2023
80b0aa3
Fix tests
borchero Oct 12, 2023
1869cfb
Fix omp parallel
borchero Oct 12, 2023
ba62bcc
Add missing <cmath> header
borchero Oct 12, 2023
db449e1
Fix cpplint
borchero Oct 12, 2023
3dab653
Disable arrow tests
borchero Oct 12, 2023
840cba9
Try fixing memory issue in tests
borchero Oct 13, 2023
19b210b
Try chunking in test
borchero Oct 13, 2023
059419d
Fix lint
borchero Oct 13, 2023
36e7bf4
Merge branch 'master' into arrow-support-training-data
borchero Oct 25, 2023
143a247
Implement review comments
borchero Oct 25, 2023
bb97817
Merge branch 'master' into arrow-support-training-data
jameslamb Oct 30, 2023
62431f2
Uninstall optional dependencies correctly
borchero Oct 30, 2023
34ee108
[python-package] Allow to pass Arrow array as labels
borchero Oct 30, 2023
90a2c1f
Fix lint
borchero Oct 30, 2023
6b65bcf
Fix lint
borchero Oct 30, 2023
ec33f75
WIP: [python-package] Allow to pass Arrow array as weights
borchero Oct 30, 2023
20a23b8
Fix lint
borchero Oct 30, 2023
ccdb0ba
Push
borchero Oct 30, 2023
7dbce53
Remove test
borchero Oct 30, 2023
ce69120
Merge branch 'arrow-support-weights' into arrow-support-groups
borchero Oct 30, 2023
e1593c2
Groups
borchero Oct 30, 2023
0af7a7c
[python-package] Allow to pass Arrow table as training data
borchero Oct 30, 2023
45a67a6
Merge branch 'arrow-support-training-data' into arrow-support-labels
borchero Oct 30, 2023
80c12c0
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Oct 30, 2023
221cba4
Merge branch 'arrow-support-weights' into arrow-support-groups
borchero Oct 30, 2023
15c8637
Fix isort
borchero Oct 30, 2023
06bdce2
Merge branch 'master' into arrow-support-labels
borchero Nov 2, 2023
75a980e
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Nov 2, 2023
3d3ffb1
Merge branch 'arrow-support-weights' into arrow-support-groups
borchero Nov 2, 2023
f7c67e7
Implement guolinke's review
borchero Nov 7, 2023
91fade9
Merge branch 'master' into arrow-support-labels
jameslamb Nov 7, 2023
09ad33b
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Nov 7, 2023
33f3e44
Merge branch 'master' into arrow-support-labels
borchero Nov 7, 2023
cd556da
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Nov 7, 2023
678ae7d
Use np_assert_array_equal
borchero Nov 7, 2023
5331202
Implement jameslamb's review comments
borchero Nov 8, 2023
74910d4
Merge branch 'master' into arrow-support-weights
jameslamb Nov 8, 2023
5041282
Merge branch 'master' into arrow-support-weights
jameslamb Nov 13, 2023
04f0f21
Merge branch 'arrow-support-weights' into arrow-support-groups
borchero Nov 14, 2023
5e2baa1
Fix
borchero Nov 14, 2023
ff5c9f8
Merge branch 'master' into arrow-support-groups
borchero Nov 14, 2023
0f56ea0
Fix and implement review comments
borchero Nov 15, 2023
797cc3a
Fix
borchero Nov 15, 2023
8714625
Fix test
borchero Nov 16, 2023
acd916e
Fix
borchero Nov 16, 2023
c00b841
Merge branch 'master' into arrow-support-groups
borchero Nov 22, 2023
9b07160
Add tests for empty chunks
borchero Nov 22, 2023
79d050b
Fix lint
borchero Nov 22, 2023

include/LightGBM/c_api.h (2 additions, 1 deletion)
@@ -558,9 +558,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
/*!
* \brief Set vector to a content in info.
* \note
* - \a group converts input datatype into ``int32``;
* - \a label and \a weight convert input datatype into ``float32``.
* \param handle Handle of dataset
* \param field_name Field name, can be \a label, \a weight
* \param field_name Field name, can be \a label, \a weight, \a group
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
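A quick illustration of the documented conversions from the Python side — a minimal sketch assuming a LightGBM build with the pyarrow support added in this PR series (the column name and row counts are made up):

```python
import lightgbm as lgb
import numpy as np
import pyarrow as pa

rng = np.random.default_rng(0)
table = pa.table({"feature": rng.random(50)})               # 50 rows of training data
weights = pa.array(rng.random(50), type=pa.float64())       # stored internally as float32
groups = pa.chunked_array([[20], [30]], type=pa.uint64())   # any integer width; stored as int32

ds = lgb.Dataset(table, weight=weights, group=groups).construct()
print(ds.get_field("weight").dtype)  # float32
print(ds.get_field("group"))         # int32 cumulative boundaries: [ 0 20 50]
```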

include/LightGBM/dataset.h (4 additions, 0 deletions)
@@ -116,6 +116,7 @@ class Metadata {
void SetWeights(const ArrowChunkedArray& array);

void SetQuery(const data_size_t* query, data_size_t len);
void SetQuery(const ArrowChunkedArray& array);

void SetPosition(const data_size_t* position, data_size_t len);

@@ -348,6 +349,9 @@ class Metadata {
void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
/*! \brief Insert queries at the given index */
void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
/*! \brief Set queries from pointers to the first element and the end of an iterator. */
template <typename It>
void SetQueriesFromIterator(It first, It last);
/*! \brief Filename of current data */
std::string data_filename_;
/*! \brief Number of data */

python-package/lightgbm/basic.py (9 additions, 6 deletions)
@@ -70,7 +70,9 @@
List[float],
List[int],
np.ndarray,
pd_Series
pd_Series,
pa_Array,
pa_ChunkedArray,
]
_LGBM_PositionType = Union[
np.ndarray,
@@ -1652,7 +1654,7 @@ def __init__(
If this is Dataset for validation, training data should be used as reference.
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
@@ -2432,7 +2434,7 @@ def create_valid(
Label of the data.
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
@@ -2889,7 +2891,7 @@ def set_group(

Parameters
----------
group : list, numpy 1-D array, pandas Series or None
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
@@ -2903,7 +2905,8 @@ def set_group(
"""
self.group = group
if self._handle is not None and group is not None:
group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
if not _is_pyarrow_array(group):
group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
self.set_field('group', group)
# original values can be modified at cpp side
constructed_group = self.get_field('group')
@@ -4431,7 +4434,7 @@ def refit(

.. versionadded:: 4.0.0

group : list, numpy 1-D array, pandas Series or None, optional (default=None)
group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Group/query size for ``data``.
Only used in the learning-to-rank task.
sum(group) = n_samples.
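The docstring updates above all describe the same contract: `group` holds per-query sizes whose sum must equal the number of rows, and it may now also be a pyarrow Array or ChunkedArray. A short usage sketch of the updated `set_group()` (row count and query sizes are made up for illustration):

```python
import lightgbm as lgb
import numpy as np
import pyarrow as pa

rng = np.random.default_rng(1)
n_rows = 60
X = rng.random((n_rows, 3))
y = rng.integers(0, 2, size=n_rows)

# three queries of 10, 20 and 30 documents: 10 + 20 + 30 == n_rows
group_sizes = pa.chunked_array([[10, 20], [30]], type=pa.int32())

ds = lgb.Dataset(X, label=y)
ds.set_group(group_sizes)  # previously limited to list / numpy / pandas inputs
ds.construct()
print(ds.get_field("group"))  # cumulative boundaries computed on the C++ side: [ 0 10 30 60]
```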

src/io/dataset.cpp (2 additions, 0 deletions)
@@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
metadata_.SetLabel(ca);
} else if (name == std::string("weight") || name == std::string("weights")) {
metadata_.SetWeights(ca);
} else if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQuery(ca);
} else {
return false;
}

src/io/metadata.cpp (20 additions, 8 deletions)
@@ -507,30 +507,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da
// CUDA is handled after all insertions are complete
}

void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
template <typename It>
void Metadata::SetQueriesFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
if (query == nullptr || len == 0) {
// Clear query boundaries on empty input
if (last - first == 0) {
query_boundaries_.clear();
num_queries_ = 0;
return;
}

data_size_t sum = 0;
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum)
for (data_size_t i = 0; i < len; ++i) {
sum += query[i];
for (data_size_t i = 0; i < last - first; ++i) {
sum += first[i];
}
if (num_data_ != sum) {
Log::Fatal("Sum of query counts is not same with #data");
Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", num_data_, sum);
}
num_queries_ = len;
num_queries_ = last - first;

query_boundaries_.resize(num_queries_ + 1);
query_boundaries_[0] = 0;
for (data_size_t i = 0; i < num_queries_; ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
query_boundaries_[i + 1] = query_boundaries_[i] + first[i];
}
CalculateQueryWeights();
query_load_from_file_ = false;

#ifdef USE_CUDA
if (cuda_metadata_ != nullptr) {
if (query_weights_.size() > 0) {
@@ -543,6 +547,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
#endif // USE_CUDA
}

void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
SetQueriesFromIterator(query, query + len);
}

void Metadata::SetQuery(const ArrowChunkedArray& array) {
SetQueriesFromIterator(array.begin<data_size_t>(), array.end<data_size_t>());
}

void Metadata::SetPosition(const data_size_t* positions, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
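For readers who only want the gist of `SetQueriesFromIterator`: it validates that the per-query counts sum to the number of rows, then prefix-sums them into query boundaries — which is also the representation `get_field("group")` returns. A rough NumPy rendering (the helper name is illustrative, not part of LightGBM):

```python
import numpy as np

def query_boundaries(group_sizes, num_data: int) -> np.ndarray:
    """Mimic Metadata::SetQueriesFromIterator: per-query counts -> cumulative boundaries."""
    sizes = np.asarray(group_sizes, dtype=np.int32)
    if sizes.sum() != num_data:
        raise ValueError(
            f"Sum of query counts ({sizes.sum()}) differs from the number of data ({num_data})"
        )
    # boundaries[i] .. boundaries[i + 1] delimit the rows belonging to query i
    return np.concatenate(([0], np.cumsum(sizes, dtype=np.int32)))

print(query_boundaries([2, 3], num_data=5))  # [0 2 5], matching the group test below
```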

tests/python_package_test/test_arrow.py (52 additions, 25 deletions)
@@ -1,7 +1,6 @@
# coding: utf-8
import filecmp
from pathlib import Path
from typing import Any, Callable, Dict
from typing import Any, Dict

import numpy as np
import pyarrow as pa
@@ -15,6 +14,21 @@
# UTILITIES #
# ----------------------------------------------------------------------------------------------- #

_INTEGER_TYPES = [
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64(),
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64(),
]
_FLOAT_TYPES = [
pa.float32(),
pa.float64(),
]


def generate_simple_arrow_table() -> pa.Table:
columns = [
@@ -85,9 +99,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
(lambda: generate_random_arrow_table(100, 10000, 43), {}),
],
)
def test_dataset_construct_fuzzy(
tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
):
def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
arrow_table = arrow_table_fn()

arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
@@ -108,17 +120,23 @@ def test_dataset_construct_fields_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_labels = generate_random_arrow_array(1000, 42)
arrow_weights = generate_random_arrow_array(1000, 42)
arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())

arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
arrow_dataset = lgb.Dataset(
arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
)
arrow_dataset.construct()

pandas_dataset = lgb.Dataset(
arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
arrow_table.to_pandas(),
label=arrow_labels.to_numpy(),
weight=arrow_weights.to_numpy(),
group=arrow_groups.to_numpy(),
)
pandas_dataset.construct()

# Check for equality
for field in ("label", "weight"):
for field in ("label", "weight", "group"):
np_assert_array_equal(
arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
)
@@ -133,22 +151,8 @@ def test_dataset_construct_fields_fuzzy():
["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
)
@pytest.mark.parametrize(
"arrow_type",
[
pa.int8(),
pa.int16(),
pa.int32(),
pa.int64(),
pa.uint8(),
pa.uint16(),
pa.uint32(),
pa.uint64(),
pa.float32(),
pa.float64(),
],
)
def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
def test_dataset_construct_labels(array_type, label_data, arrow_type):
data = generate_dummy_arrow_table()
labels = array_type(label_data, type=arrow_type)
dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
@@ -175,11 +179,34 @@ def test_dataset_construct_weights_none():
[(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
)
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
def test_dataset_construct_weights(array_type, weight_data, arrow_type):
data = generate_dummy_arrow_table()
weights = array_type(weight_data, type=arrow_type)
dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
dataset.construct()

expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
np_assert_array_equal(expected, dataset.get_weight(), strict=True)


# -------------------------------------------- GROUPS ------------------------------------------- #


@pytest.mark.parametrize(
["array_type", "group_data"],
[
(pa.array, [2, 3]),
(pa.chunked_array, [[2], [3]]),
(pa.chunked_array, [[], [2, 3]]),
(pa.chunked_array, [[2], [], [3], []]),
],
)
@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES)
def test_dataset_construct_groups(array_type, group_data, arrow_type):
data = generate_dummy_arrow_table()
groups = array_type(group_data, type=arrow_type)
dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params())
dataset.construct()

expected = np.array([0, 2, 5], dtype=np.int32)
np_assert_array_equal(expected, dataset.get_field("group"), strict=True)
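
The new test above only exercises Dataset construction (including empty Arrow chunks). As a closing usage note, here is an end-to-end ranking sketch of what the PR series enables — Arrow table as features, Arrow arrays as label and group — with made-up data; treat it as an illustration rather than part of the test suite:

```python
import lightgbm as lgb
import numpy as np
import pyarrow as pa

rng = np.random.default_rng(42)
n_rows = 1_000

table = pa.table({f"f{i}": rng.random(n_rows) for i in range(5)})
labels = pa.array(rng.integers(0, 4, size=n_rows), type=pa.int32())  # graded relevance 0..3
# ten queries of 100 documents each, split across two chunks; sizes sum to n_rows
groups = pa.chunked_array([[100] * 5, [100] * 5], type=pa.int32())

train_set = lgb.Dataset(table, label=labels, group=groups)
booster = lgb.train(
    {"objective": "lambdarank", "verbosity": -1},
    train_set,
    num_boost_round=10,
)
print(booster.num_trees())  # 10
```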