diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index a46f8332811a..fd337cbc7cbe 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -558,9 +558,9 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle, /*! * \brief Set vector to a content in info. * \note - * - \a label convert input datatype into ``float32``. + * - \a label and \a weight convert input datatype into ``float32``. * \param handle Handle of dataset - * \param field_name Field name, can be \a label + * \param field_name Field name, can be \a label, \a weight * \param n_chunks The number of Arrow arrays passed to this function * \param chunks Pointer to the list of Arrow arrays * \param schema Pointer to the schema of all Arrow arrays diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index 56bc7b841dc3..48c1bee804d7 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -113,6 +113,7 @@ class Metadata { void SetLabel(const ArrowChunkedArray& array); void SetWeights(const label_t* weights, data_size_t len); + void SetWeights(const ArrowChunkedArray& array); void SetQuery(const data_size_t* query, data_size_t len); @@ -340,6 +341,9 @@ class Metadata { void SetLabelsFromIterator(It first, It last); /*! \brief Insert weights at the given index */ void InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len); + /*! \brief Set weights from pointers to the first element and the end of an iterator. */ + template + void SetWeightsFromIterator(It first, It last); /*! \brief Insert initial scores at the given index */ void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size); /*! \brief Insert queries at the given index */ diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index e8d8bd84cbe7..939842df3389 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -19,7 +19,8 @@ import scipy.sparse from .compat import (PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, arrow_is_floating, arrow_is_integer, concat, - dt_DataTable, pa_Array, pa_ChunkedArray, pa_Table, pd_CategoricalDtype, pd_DataFrame, pd_Series) + dt_DataTable, pa_Array, pa_ChunkedArray, pa_compute, pa_Table, pd_CategoricalDtype, pd_DataFrame, + pd_Series) from .libpath import find_lib_path if TYPE_CHECKING: @@ -115,7 +116,9 @@ List[float], List[int], np.ndarray, - pd_Series + pd_Series, + pa_Array, + pa_ChunkedArray, ] ZERO_THRESHOLD = 1e-35 @@ -1635,7 +1638,7 @@ def __init__( Label of the data. reference : Dataset or None, optional (default=None) If this is Dataset for validation, training data should be used as reference. - weight : list, numpy 1-D array, pandas Series or None, optional (default=None) + weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) Weight for each instance. Weights should be non-negative. group : list, numpy 1-D array, pandas Series or None, optional (default=None) Group/query data. @@ -2415,7 +2418,7 @@ def create_valid( If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM) or a LightGBM Dataset binary file. label : list, numpy 1-D array, pandas Series / one-column DataFrame, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) Label of the data. - weight : list, numpy 1-D array, pandas Series or None, optional (default=None) + weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) Weight for each instance. Weights should be non-negative. group : list, numpy 1-D array, pandas Series or None, optional (default=None) Group/query data. @@ -2830,7 +2833,7 @@ def set_weight( Parameters ---------- - weight : list, numpy 1-D array, pandas Series or None + weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None Weight to be set for each data point. Weights should be non-negative. Returns @@ -2838,11 +2841,19 @@ def set_weight( self : Dataset Dataset with set weight. """ - if weight is not None and np.all(weight == 1): - weight = None + # Check if the weight contains values other than one + if weight is not None: + if _is_pyarrow_array(weight): + if pa_compute.all(pa_compute.equal(weight, 1)).as_py(): + weight = None + elif np.all(weight == 1): + weight = None self.weight = weight + + # Set field if self._handle is not None and weight is not None: - weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight') + if not _is_pyarrow_array(weight): + weight = _list_to_1d_numpy(weight, dtype=np.float32, name='weight') self.set_field('weight', weight) self.weight = self.get_field('weight') # original values can be modified at cpp side return self @@ -4414,7 +4425,7 @@ def refit( .. versionadded:: 4.0.0 - weight : list, numpy 1-D array, pandas Series or None, optional (default=None) + weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None) Weight for each ``data`` instance. Weights should be non-negative. .. versionadded:: 4.0.0 diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 984972ed1ae3..dc48dbf792cf 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -197,6 +197,7 @@ def __init__(self, *args, **kwargs): """pyarrow""" try: + import pyarrow.compute as pa_compute from pyarrow import Array as pa_Array from pyarrow import ChunkedArray as pa_ChunkedArray from pyarrow import Table as pa_Table @@ -236,6 +237,12 @@ class arrow_cffi: # type: ignore def __init__(self, *args, **kwargs): pass + class pa_compute: # type: ignore + """Dummy class for pyarrow.compute.""" + + all = None + equal = None + arrow_is_integer = None arrow_is_floating = None diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index e78f8a6b696c..01eb41b71367 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -902,6 +902,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray name = Common::Trim(name); if (name == std::string("label") || name == std::string("target")) { metadata_.SetLabel(ca); + } else if (name == std::string("weight") || name == std::string("weights")) { + metadata_.SetWeights(ca); } else { return false; } diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index 41f9e3bf43c6..ed4fb135e62a 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -450,33 +450,45 @@ void Metadata::InsertLabels(const label_t* labels, data_size_t start_index, data // CUDA is handled after all insertions are complete } -void Metadata::SetWeights(const label_t* weights, data_size_t len) { +template +void Metadata::SetWeightsFromIterator(It first, It last) { std::lock_guard lock(mutex_); - // save to nullptr - if (weights == nullptr || len == 0) { + // Clear weights on empty input + if (last - first == 0) { weights_.clear(); num_weights_ = 0; return; } - if (num_data_ != len) { - Log::Fatal("Length of weights is not same with #data"); + if (num_data_ != last - first) { + Log::Fatal("Length of weights differs from the length of #data"); + } + if (weights_.empty()) { + weights_.resize(num_data_); } - if (weights_.empty()) { weights_.resize(num_data_); } num_weights_ = num_data_; #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static, 512) if (num_weights_ >= 1024) for (data_size_t i = 0; i < num_weights_; ++i) { - weights_[i] = Common::AvoidInf(weights[i]); + weights_[i] = Common::AvoidInf(first[i]); } CalculateQueryWeights(); weight_load_from_file_ = false; + #ifdef USE_CUDA if (cuda_metadata_ != nullptr) { - cuda_metadata_->SetWeights(weights_.data(), len); + cuda_metadata_->SetWeights(weights_.data(), weights_.size()); } #endif // USE_CUDA } +void Metadata::SetWeights(const label_t* weights, data_size_t len) { + SetWeightsFromIterator(weights, weights + len); +} + +void Metadata::SetWeights(const ArrowChunkedArray& array) { + SetWeightsFromIterator(array.begin(), array.end()); +} + void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, data_size_t len) { if (!weights) { Log::Fatal("Passed null weights"); diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index 1dd270c8ec53..40482a904a62 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -9,6 +9,8 @@ import lightgbm as lgb +from .utils import np_assert_array_equal + # ----------------------------------------------------------------------------------------------- # # UTILITIES # # ----------------------------------------------------------------------------------------------- # @@ -67,10 +69,6 @@ def dummy_dataset_params() -> Dict[str, Any]: } -def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray): - assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs) - - # ----------------------------------------------------------------------------------------------- # # UNIT TESTS # # ----------------------------------------------------------------------------------------------- # @@ -103,6 +101,34 @@ def test_dataset_construct_fuzzy( assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt") +# -------------------------------------------- FIELDS ------------------------------------------- # + + +def test_dataset_construct_fields_fuzzy(): + arrow_table = generate_random_arrow_table(3, 1000, 42) + arrow_labels = generate_random_arrow_array(1000, 42) + arrow_weights = generate_random_arrow_array(1000, 42) + + arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights) + arrow_dataset.construct() + + pandas_dataset = lgb.Dataset( + arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy() + ) + pandas_dataset.construct() + + # Check for equality + for field in ("label", "weight"): + np_assert_array_equal( + arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True + ) + np_assert_array_equal(arrow_dataset.get_label(), pandas_dataset.get_label(), strict=True) + np_assert_array_equal(arrow_dataset.get_weight(), pandas_dataset.get_weight(), strict=True) + + +# -------------------------------------------- LABELS ------------------------------------------- # + + @pytest.mark.parametrize( ["array_type", "label_data"], [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])], @@ -129,17 +155,31 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: dataset.construct() expected = np.array([0, 1, 0, 0, 1], dtype=np.float32) - assert_arrays_equal(expected, dataset.get_label()) + np_assert_array_equal(expected, dataset.get_label(), strict=True) -def test_dataset_construct_labels_fuzzy(): - arrow_table = generate_random_arrow_table(3, 1000, 42) - arrow_array = generate_random_arrow_array(1000, 42) +# ------------------------------------------- WEIGHTS ------------------------------------------- # - arrow_dataset = lgb.Dataset(arrow_table, label=arrow_array) - arrow_dataset.construct() - pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), label=arrow_array.to_numpy()) - pandas_dataset.construct() +def test_dataset_construct_weights_none(): + data = generate_dummy_arrow_table() + weight = pa.array([1, 1, 1, 1, 1]) + dataset = lgb.Dataset(data, weight=weight, params=dummy_dataset_params()) + dataset.construct() + assert dataset.get_weight() is None + assert dataset.get_field("weight") is None + + +@pytest.mark.parametrize( + ["array_type", "weight_data"], + [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])], +) +@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()]) +def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any): + data = generate_dummy_arrow_table() + weights = array_type(weight_data, type=arrow_type) + dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params()) + dataset.construct() - assert_arrays_equal(arrow_dataset.get_label(), pandas_dataset.get_label()) + expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32) + np_assert_array_equal(expected, dataset.get_weight(), strict=True)