Skip to content

Commit

Permalink
Extend tests
Browse files Browse the repository at this point in the history
  • Loading branch information
borchero committed Dec 12, 2023
1 parent c013979 commit 69859df
Showing 1 changed file with 107 additions and 17 deletions.
124 changes: 107 additions & 17 deletions tests/python_package_test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,37 @@ def generate_dummy_arrow_table() -> pa.Table:
return pa.Table.from_arrays([col1, col2], names=["a", "b"])


def generate_random_arrow_table(
    num_columns: int,
    num_datapoints: int,
    seed: int,
    generate_nulls: bool = True,
    values: np.ndarray | None = None,
) -> pa.Table:
    """Generate a ``pa.Table`` with ``num_columns`` random columns named ``col_<i>``.

    Each column is produced by :func:`generate_random_arrow_array` with seed
    ``seed + i`` so that columns differ but the table is reproducible overall.

    Parameters
    ----------
    num_columns
        Number of columns in the resulting table.
    num_datapoints
        Number of rows per column.
    seed
        Base seed; column ``i`` uses ``seed + i``.
    generate_nulls
        Whether to randomly set ~10% of each column's entries to null.
    values
        Optional pool of values to sample from; if ``None``, values are drawn
        from a standard normal distribution.
    """
    columns = [
        generate_random_arrow_array(
            num_datapoints, seed + i, generate_nulls=generate_nulls, values=values
        )
        for i in range(num_columns)
    ]
    names = [f"col_{i}" for i in range(num_columns)]
    return pa.Table.from_arrays(columns, names=names)


def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
def generate_random_arrow_array(
num_datapoints: int, seed: int, generate_nulls: bool = True, values: np.ndarray | None = None
) -> pa.ChunkedArray:
generator = np.random.default_rng(seed)
data = generator.standard_normal(num_datapoints)
data = (
generator.standard_normal(num_datapoints)
if values is None
else generator.choice(values, size=num_datapoints, replace=True)
)

# Set random nulls
indices = generator.choice(len(data), size=num_datapoints // 10)
data[indices] = None
if generate_nulls:
indices = generator.choice(len(data), size=num_datapoints // 10)
data[indices] = None

# Split data into <=2 random chunks
split_points = np.sort(generator.choice(np.arange(1, num_datapoints), 2, replace=False))
Expand Down Expand Up @@ -131,8 +149,8 @@ def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):

def test_dataset_construct_fields_fuzzy():
arrow_table = generate_random_arrow_table(3, 1000, 42)
arrow_labels = generate_random_arrow_array(1000, 42)
arrow_weights = generate_random_arrow_array(1000, 42)
arrow_labels = generate_random_arrow_array(1000, 42, generate_nulls=False)
arrow_weights = generate_random_arrow_array(1000, 42, generate_nulls=False)
arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())

arrow_dataset = lgb.Dataset(
Expand Down Expand Up @@ -264,9 +282,9 @@ def test_dataset_construct_init_scores_table():
data = generate_dummy_arrow_table()
init_scores = pa.Table.from_arrays(
[
generate_random_arrow_array(5, seed=1),
generate_random_arrow_array(5, seed=2),
generate_random_arrow_array(5, seed=3),
generate_random_arrow_array(5, seed=1, generate_nulls=False),
generate_random_arrow_array(5, seed=2, generate_nulls=False),
generate_random_arrow_array(5, seed=3, generate_nulls=False),
],
names=["a", "b", "c"],
)
Expand All @@ -281,12 +299,84 @@ def test_dataset_construct_init_scores_table():
# ------------------------------------------ PREDICTION ----------------------------------------- #


@pytest.mark.parametrize(
    ("objective", "labels_fn", "groups_fn", "extra_params"),
    [
        (
            "regression",
            lambda: generate_random_arrow_array(10000, 43, generate_nulls=False),
            lambda: None,
            {},
        ),
        (
            "binary",
            lambda: generate_random_arrow_array(
                10000, 43, generate_nulls=False, values=np.arange(2)
            ),
            lambda: None,
            {},
        ),
        (
            "multiclass",
            lambda: generate_random_arrow_array(
                10000, 43, generate_nulls=False, values=np.arange(5)
            ),
            lambda: None,
            {"num_class": 5},
        ),
        (
            "cross_entropy",
            lambda: generate_random_arrow_array(
                10000, 43, generate_nulls=False, values=np.linspace(0, 1, num=50)
            ),
            lambda: None,
            {},
        ),
        (
            "lambdarank",
            lambda: generate_random_arrow_array(
                10000, 43, generate_nulls=False, values=np.arange(4)
            ),
            lambda: np.array([1000, 2000, 3000, 4000]),
            {},
        ),
    ],
)
@pytest.mark.parametrize("num_iteration", [None, 5])
@pytest.mark.parametrize("raw_score", [True, False])
@pytest.mark.parametrize("pred_leaf", [True, False])
@pytest.mark.parametrize("pred_contrib", [True, False])
def test_predict(
    objective,
    labels_fn,
    groups_fn,
    extra_params,
    num_iteration,
    raw_score,
    pred_leaf,
    pred_contrib,
):
    """Check that predicting on an Arrow table matches predicting on pandas.

    Trains a booster for the given objective on random Arrow data, then runs
    ``Booster.predict`` once on the Arrow table and once on its pandas
    conversion (with identical prediction options) and asserts the outputs
    are element-wise equal.

    Parameters
    ----------
    objective
        LightGBM objective to train with.
    labels_fn, groups_fn
        Zero-argument callables producing the label array and (optional)
        query groups; callables so each parametrized run builds fresh data.
    extra_params
        Extra training params required by the objective (e.g. ``num_class``).
    num_iteration, raw_score, pred_leaf, pred_contrib
        Prediction options forwarded to ``Booster.predict``.
    """
    data = generate_random_arrow_table(10, 10000, 42)
    dataset = lgb.Dataset(
        data, label=labels_fn(), group=groups_fn(), params=dummy_dataset_params()
    )
    booster = lgb.train(
        {
            "objective": objective,
            **extra_params,
        },
        dataset,
        num_boost_round=10,
    )

    # Same options for both prediction paths so only the input type differs.
    pred_kwargs = dict(
        num_iteration=num_iteration,
        raw_score=raw_score,
        pred_leaf=pred_leaf,
        pred_contrib=pred_contrib,
    )
    out_arrow = booster.predict(data, **pred_kwargs)
    out_pandas = booster.predict(data.to_pandas(), **pred_kwargs)
    np_assert_array_equal(out_arrow, out_pandas, strict=True)

0 comments on commit 69859df

Please sign in to comment.