diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index 81a426a2456f..3e188f647138 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -9,6 +9,8 @@ import lightgbm as lgb +from .utils import np_assert_array_equal + # ----------------------------------------------------------------------------------------------- # # UTILITIES # # ----------------------------------------------------------------------------------------------- # @@ -67,10 +69,6 @@ def dummy_dataset_params() -> Dict[str, Any]: } -def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray): - assert lhs.dtype == rhs.dtype and np.array_equal(lhs, rhs) - - # ----------------------------------------------------------------------------------------------- # # UNIT TESTS # # ----------------------------------------------------------------------------------------------- # @@ -103,6 +101,29 @@ def test_dataset_construct_fuzzy( assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt") +# -------------------------------------------- FIELDS ------------------------------------------- # + + +@pytest.mark.parametrize("field", ["label", "weight"]) +def test_dataset_construct_fields_fuzzy(field: str): + arrow_table = generate_random_arrow_table(3, 1000, 42) + arrow_array = generate_random_arrow_array(1000, 42) + + arrow_dataset = lgb.Dataset(arrow_table, **{field: arrow_array}) + arrow_dataset.construct() + + pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), **{field: arrow_array.to_numpy()}) + pandas_dataset.construct() + + np_assert_array_equal(arrow_dataset.get_field(field), pandas_dataset.get_field(field)) + np_assert_array_equal( + getattr(arrow_dataset, f"get_{field}")(), getattr(pandas_dataset, f"get_{field}")() + ) + + +# -------------------------------------------- LABELS ------------------------------------------- # + + @pytest.mark.parametrize( ["array_type", "label_data"], [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])], @@ -129,24 +150,10 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: dataset.construct() expected = np.array([0, 1, 0, 0, 1], dtype=np.float32) - assert_arrays_equal(expected, dataset.get_label()) - + np_assert_array_equal(expected, dataset.get_label()) -@pytest.mark.parametrize("field", ["label", "weight"]) -def test_dataset_construct_fields_fuzzy(field: str): - arrow_table = generate_random_arrow_table(3, 1000, 42) - arrow_array = generate_random_arrow_array(1000, 42) - arrow_dataset = lgb.Dataset(arrow_table, **{field: arrow_array}) - arrow_dataset.construct() - - pandas_dataset = lgb.Dataset(arrow_table.to_pandas(), **{field: arrow_array.to_numpy()}) - pandas_dataset.construct() - - assert_arrays_equal(arrow_dataset.get_field(field), pandas_dataset.get_field(field)) - assert_arrays_equal( - getattr(arrow_dataset, f"get_{field}")(), getattr(pandas_dataset, f"get_{field}")() - ) +# ------------------------------------------- WEIGHTS ------------------------------------------- # def test_dataset_construct_weights_none(): @@ -169,4 +176,4 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type dataset.construct() expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32) - assert_arrays_equal(expected, dataset.get_weight()) + np_assert_array_equal(expected, dataset.get_weight())