Skip to content

Commit

Permalink
Use np_assert_array_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Nov 7, 2023
1 parent cd556da commit 678ae7d
Showing 1 changed file with 28 additions and 21 deletions.
49 changes: 28 additions & 21 deletions tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import lightgbm as lgb

from .utils import np_assert_array_equal

# ----------------------------------------------------------------------------------------------- #
# UTILITIES #
# ----------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -67,10 +69,6 @@ def dummy_dataset_params() -> Dict[str, Any]:
}


def assert_arrays_equal(lhs: np.ndarray, rhs: np.ndarray) -> None:
    """Assert that two NumPy arrays share the same dtype and the same contents.

    The dtype is checked first and the element-wise comparison second, so a
    failure states which of the two conditions was violated.

    Parameters
    ----------
    lhs : np.ndarray
        First array to compare.
    rhs : np.ndarray
        Second array to compare.

    Raises
    ------
    AssertionError
        If the dtypes differ, or if ``np.array_equal(lhs, rhs)`` is false.
    """
    # Checked separately (rather than in one ``and`` expression) purely for diagnostics.
    assert lhs.dtype == rhs.dtype, f"dtype mismatch: {lhs.dtype} != {rhs.dtype}"
    assert np.array_equal(lhs, rhs), f"array contents differ:\nlhs={lhs!r}\nrhs={rhs!r}"


# ----------------------------------------------------------------------------------------------- #
# UNIT TESTS #
# ----------------------------------------------------------------------------------------------- #
Expand Down Expand Up @@ -103,6 +101,29 @@ def test_dataset_construct_fuzzy(
assert filecmp.cmp(tmp_path / "arrow.txt", tmp_path / "pandas.txt")


# -------------------------------------------- FIELDS ------------------------------------------- #


@pytest.mark.parametrize("field", ["label", "weight"])
def test_dataset_construct_fields_fuzzy(field: str):
    """Check that a field passed as an Arrow array matches the same field passed via NumPy.

    Two datasets are built from the same generated data: one directly from the
    Arrow table and Arrow array, the other from their pandas / NumPy
    conversions. The parametrized ``field`` must then read back identically
    from both, through ``get_field`` and through the ``get_<field>`` accessor.
    """
    # NOTE(review): the arguments presumably mean (columns, rows, seed) and (rows, seed) -- confirm
    # against the generator helpers.
    features = generate_random_arrow_table(3, 1000, 42)
    values = generate_random_arrow_array(1000, 42)

    from_arrow = lgb.Dataset(features, **{field: values})
    from_arrow.construct()

    from_pandas = lgb.Dataset(features.to_pandas(), **{field: values.to_numpy()})
    from_pandas.construct()

    # Generic accessor.
    np_assert_array_equal(from_arrow.get_field(field), from_pandas.get_field(field))

    # Field-specific accessor, looked up by name (e.g. ``get_label`` / ``get_weight``).
    getter_name = f"get_{field}"
    np_assert_array_equal(getattr(from_arrow, getter_name)(), getattr(from_pandas, getter_name)())


# -------------------------------------------- LABELS ------------------------------------------- #


@pytest.mark.parametrize(
["array_type", "label_data"],
[(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
Expand All @@ -129,24 +150,10 @@ def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type:
dataset.construct()

expected = np.array([0, 1, 0, 0, 1], dtype=np.float32)
assert_arrays_equal(expected, dataset.get_label())

np_assert_array_equal(expected, dataset.get_label())

@pytest.mark.parametrize("field", ["label", "weight"])
def test_dataset_construct_fields_fuzzy(field: str):
    """Compare a dataset field supplied as an Arrow array with the same field supplied via NumPy.

    One dataset is constructed from the Arrow table and Arrow array, a second
    from their ``to_pandas()`` / ``to_numpy()`` conversions. The values read
    back for ``field`` must agree in dtype and contents, both via ``get_field``
    and via the matching ``get_<field>`` method.
    """
    # NOTE(review): generator arguments look like sizes plus a seed -- confirm against the helpers.
    table = generate_random_arrow_table(3, 1000, 42)
    column = generate_random_arrow_array(1000, 42)

    ds_arrow = lgb.Dataset(table, **{field: column})
    ds_arrow.construct()

    ds_pandas = lgb.Dataset(table.to_pandas(), **{field: column.to_numpy()})
    ds_pandas.construct()

    assert_arrays_equal(ds_arrow.get_field(field), ds_pandas.get_field(field))

    # Resolve the dedicated accessor by name so one test body serves every parametrized field.
    accessor = f"get_{field}"
    assert_arrays_equal(getattr(ds_arrow, accessor)(), getattr(ds_pandas, accessor)())
# ------------------------------------------- WEIGHTS ------------------------------------------- #


def test_dataset_construct_weights_none():
Expand All @@ -169,4 +176,4 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type
dataset.construct()

expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
assert_arrays_equal(expected, dataset.get_weight())
np_assert_array_equal(expected, dataset.get_weight())

0 comments on commit 678ae7d

Please sign in to comment.