Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] Allow to pass Arrow array as weights #6164

Merged
merged 56 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
ab2d5e2
Add Arrow support to Python API
borchero Jul 31, 2023
570ca64
Merge branch 'master' into arrow-support
borchero Aug 5, 2023
c21fab4
Fix lint
borchero Aug 5, 2023
2cd4302
Fix isort
borchero Aug 5, 2023
71957f6
[python-package] Allow to pass Arrow table as training data
borchero Aug 12, 2023
175fb13
Merge branch 'master' into arrow-support-training-data
borchero Aug 12, 2023
32dfb11
Remove change
borchero Aug 12, 2023
b5f0676
Implement JL comments
borchero Aug 12, 2023
cca3b37
Fix isort
borchero Aug 12, 2023
001139a
Remove testcase
borchero Aug 12, 2023
5861ca6
Adjust pyarrow version
borchero Aug 12, 2023
54d171c
Revert gitignore
borchero Aug 21, 2023
a87a15b
Fix lint
borchero Sep 5, 2023
8cda7cd
Merge branch 'master' into arrow-support-training-data
borchero Sep 5, 2023
6b4245a
Increase timeout for bdist_wheel build
borchero Sep 6, 2023
14a9326
Fix layout
borchero Sep 6, 2023
854f306
Add newline
borchero Sep 6, 2023
269582c
Fix typo
borchero Sep 11, 2023
9164040
Merge branch 'master' into arrow-support-training-data
borchero Sep 11, 2023
9a0a18d
Merge branch 'master' into arrow-support-training-data
borchero Sep 15, 2023
e5540cd
Remove arrow.py
borchero Sep 15, 2023
98997bf
Merge branch 'master' into arrow-support-training-data
jameslamb Sep 26, 2023
4a66cba
Merge branch 'master' into arrow-support-training-data
borchero Oct 12, 2023
f44421e
Fix cpp tests
borchero Oct 12, 2023
80b0aa3
Fix tests
borchero Oct 12, 2023
1869cfb
Fix omp parallel
borchero Oct 12, 2023
ba62bcc
Add missing <cmath> header
borchero Oct 12, 2023
db449e1
Fix cpplint
borchero Oct 12, 2023
3dab653
Disable arrow tests
borchero Oct 12, 2023
840cba9
Try fixing memory issue in tests
borchero Oct 13, 2023
19b210b
Try chunking in test
borchero Oct 13, 2023
059419d
Fix lint
borchero Oct 13, 2023
36e7bf4
Merge branch 'master' into arrow-support-training-data
borchero Oct 25, 2023
143a247
Implement review comments
borchero Oct 25, 2023
bb97817
Merge branch 'master' into arrow-support-training-data
jameslamb Oct 30, 2023
62431f2
Uninstall optional dependencies correctly
borchero Oct 30, 2023
34ee108
[python-package] Allow to pass Arrow array as labels
borchero Oct 30, 2023
90a2c1f
Fix lint
borchero Oct 30, 2023
6b65bcf
Fix lint
borchero Oct 30, 2023
ec33f75
WIP: [python-package] Allow to pass Arrow array as weights
borchero Oct 30, 2023
20a23b8
Fix lint
borchero Oct 30, 2023
7dbce53
Remove test
borchero Oct 30, 2023
0af7a7c
[python-package] Allow to pass Arrow table as training data
borchero Oct 30, 2023
45a67a6
Merge branch 'arrow-support-training-data' into arrow-support-labels
borchero Oct 30, 2023
80c12c0
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Oct 30, 2023
06bdce2
Merge branch 'master' into arrow-support-labels
borchero Nov 2, 2023
75a980e
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Nov 2, 2023
f7c67e7
Implement guolinke's review
borchero Nov 7, 2023
91fade9
Merge branch 'master' into arrow-support-labels
jameslamb Nov 7, 2023
09ad33b
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Nov 7, 2023
33f3e44
Merge branch 'master' into arrow-support-labels
borchero Nov 7, 2023
cd556da
Merge branch 'arrow-support-labels' into arrow-support-weights
borchero Nov 7, 2023
678ae7d
Use np_assert_array_equal
borchero Nov 7, 2023
5331202
Implement jameslamb's review comments
borchero Nov 8, 2023
74910d4
Merge branch 'master' into arrow-support-weights
jameslamb Nov 8, 2023
5041282
Merge branch 'master' into arrow-support-weights
jameslamb Nov 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
/*!
* \brief Set vector to a content in info.
* \note
* - \a label convert input datatype into ``float32``.
* - \a label and \a weight convert input datatype into ``float32``.
* \param handle Handle of dataset
* \param field_name Field name, can be \a label
* \param field_name Field name, can be \a label, \a weight
* \param n_chunks The number of Arrow arrays passed to this function
* \param chunks Pointer to the list of Arrow arrays
* \param schema Pointer to the schema of all Arrow arrays
Expand Down
4 changes: 4 additions & 0 deletions include/LightGBM/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class Metadata {
void SetLabel(const ArrowChunkedArray& array);

void SetWeights(const label_t* weights, data_size_t len);
void SetWeights(const ArrowChunkedArray& array);

void SetQuery(const data_size_t* query, data_size_t len);

Expand Down Expand Up @@ -340,6 +341,9 @@ class Metadata {
void SetLabelsFromIterator(It first, It last);
/*! \brief Insert weights at the given index */
void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len);
/*! \brief Set weights from pointers to the first element and the end of an iterator. */
template <typename It>
void SetWeightsFromIterator(It first, It last);
/*! \brief Insert initial scores at the given index */
void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
/*! \brief Insert queries at the given index */
Expand Down
29 changes: 20 additions & 9 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import scipy.sparse

from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat,
dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series)
dt_DataTable, pa_Array, pa_ChunkedArray, pa_compute, pa_Table, pd_CategoricalDtype, pd_DataFrame,
pd_Series)
from .libpath import find_lib_path

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,7 +116,9 @@
List[float],
List[int],
np.ndarray,
pd_Series
pd_Series,
pa_Array,
pa_ChunkedArray,
]
ZERO_THRESHOLD = 1e-35

Expand Down Expand Up @@ -1635,7 +1638,7 @@ def __init__(
Label of the data.
reference : Dataset or None, optional (default=None)
If this is Dataset for validation, training data should be used as reference.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
Group/query data.
Expand Down Expand Up @@ -2415,7 +2418,7 @@ def create_valid(
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file.
label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Label of the data.
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each instance. Weights should be non-negative.
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
Group/query data.
Expand Down Expand Up @@ -2830,19 +2833,27 @@ def set_weight(

Parameters
----------
weight : list, numpy 1-D array, pandas Series or None
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
Weight to be set for each data point. Weights should be non-negative.

Returns
-------
self : Dataset
Dataset with set weight.
"""
if weight is not None and np.all(weight == 1):
weight = None
# Check if the weight contains values other than one
if weight is not None:
if _is_pyarrow_array(weight):
if pa_compute.all(pa_compute.equal(weight, 1)).as_py():
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
weight = None
elif np.all(weight == 1):
weight = None
self.weight = weight

# Set field
if self._handle is not None and weight is not None:
weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
if not _is_pyarrow_array(weight):
weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight')
self.set_field('weight', weight)
self.weight = self.get_field('weight') # original values can be modified at cpp side
return self
Expand Down Expand Up @@ -4414,7 +4425,7 @@ def refit(

.. versionadded:: 4.0.0

weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
Weight for each ``data`` instance. Weights should be non-negative.

.. versionadded:: 4.0.0
Expand Down
7 changes: 7 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def __init__(self, *args, **kwargs):

"""pyarrow"""
try:
import pyarrow.compute as pa_compute
from pyarrow import Array as pa_Array
from pyarrow import ChunkedArray as pa_ChunkedArray
from pyarrow import Table as pa_Table
Expand Down Expand Up @@ -236,6 +237,12 @@ class arrow_cffi: # type: ignore
def __init__(self, *args, **kwargs):
pass

class pa_compute: # type: ignore
"""Dummy class for pyarrow.compute."""

all = None
equal = None

arrow_is_integer = None
arrow_is_floating = None

Expand Down
2 changes: 2 additions & 0 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
name = Common::Trim(name);
if (name == std::string("label") || name == std::string("target")) {
metadata_.SetLabel(ca);
} else if (name == std::string("weight") || name == std::string("weights")) {
metadata_.SetWeights(ca);
} else {
return false;
}
Expand Down
28 changes: 20 additions & 8 deletions src/io/metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,33 +450,45 @@ void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data
// CUDA is handled after all insertions are complete
}

void Metadata::SetWeights(const label_t* weights, data_size_t len) {
template <typename It>
void Metadata::SetWeightsFromIterator(It first, It last) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
if (weights == nullptr || len == 0) {
// Clear weights on empty input
if (last - first == 0) {
weights_.clear();
num_weights_ = 0;
return;
}
if (num_data_ != len) {
Log::Fatal("Length of weights is not same with #data");
if (num_data_ != last - first) {
Log::Fatal("Length of weights differs from the length of #data");
}
if (weights_.empty()) {
weights_.resize(num_data_);
}
if (weights_.empty()) { weights_.resize(num_data_); }
num_weights_ = num_data_;

#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_weights_ >= 1024)
for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = Common::AvoidInf(weights[i]);
weights_[i] = Common::AvoidInf(first[i]);
}
CalculateQueryWeights();
weight_load_from_file_ = false;

#ifdef USE_CUDA
if (cuda_metadata_ != nullptr) {
cuda_metadata_->SetWeights(weights_.data(), len);
cuda_metadata_->SetWeights(weights_.data(), weights_.size());
}
#endif // USE_CUDA
}

void Metadata::SetWeights(const label_t* weights, data_size_t len) {
SetWeightsFromIterator(weights, weights + len);
}

void Metadata::SetWeights(const ArrowChunkedArray& array) {
SetWeightsFromIterator(array.begin<label_t>(), array.end<label_t>());
}

void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len) {
if (!weights) {
Log::Fatal("Passed null weights");
Expand Down
66 changes: 53 additions & 13 deletions tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import lightgbm as lgb

from .utils import np_assert_array_equal

# ----------------------------------------------------------------------------------------------- #
# UTILITIES #
# ----------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -67,10 +69,6 @@ def dummy_dataset_params() -> Dict[str, Any]:
}


def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray):
assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs)


# ----------------------------------------------------------------------------------------------- #
# UNIT TESTS #
# ----------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -103,6 +101,34 @@ def test_dataset_construct_fuzzy(
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")


# -------------------------------------------- FIELDS ------------------------------------------- #


def test_dataset_construct_fields_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_labels = generate_random_arrow_array(1000, 42)
arrow_weights = generate_random_arrow_array(1000, 42)

arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
arrow_dataset.construct()

pandas_dataset = lgb.Dataset(
arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
)
pandas_dataset.construct()

# Check for equality
for field in ("label", "weight"):
np_assert_array_equal(
arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
)
np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True)
np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True)


# -------------------------------------------- LABELS ------------------------------------------- #


@pytest.mark.parametrize(
["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
Expand All @@ -129,17 +155,31 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type:
dataset.construct()

expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
assert_arrays_equal(expected, dataset.get_label())
np_assert_array_equal(expected, dataset.get_label(), strict=True)


def test_dataset_construct_labels_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_array = generate_random_arrow_array(1000, 42)
# ------------------------------------------- WEIGHTS ------------------------------------------- #

arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array)
arrow_dataset.construct()

pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy())
pandas_dataset.construct()
def test_dataset_construct_weights_none():
data = generate_dummy_arrow_table()
weight = pa.array([1, 1, 1, 1, 1])
dataset = lgb.Dataset(data, weight=weight, params=dummy_dataset_params())
dataset.construct()
assert dataset.get_weight() is None
borchero marked this conversation as resolved.
Show resolved Hide resolved
assert dataset.get_field("weight") is None


@pytest.mark.parametrize(
["array_type", "weight_data"],
[(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
)
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
data = generate_dummy_arrow_table()
weights = array_type(weight_data, type=arrow_type)
dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
dataset.construct()

assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label())
expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
np_assert_array_equal(expected, dataset.get_weight(), strict=True)