Fix and implement review comments
borchero committed Nov 15, 2023
1 parent ff5c9f8 commit 0f56ea0
Showing 2 changed files with 11 additions and 14 deletions.
6 changes: 3 additions & 3 deletions src/io/metadata.cpp
@@ -510,20 +510,20 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da
template <typename It>
void Metadata::SetQueriesFromIterator(It first, It last) {
  std::lock_guard<std::mutex> lock(mutex_);
-  // Clear weights on empty input
+  // Clear query boundaries on empty input
  if (last - first == 0) {
    query_boundaries_.clear();
    num_queries_ = 0;
    return;
  }

  data_size_t sum = 0;
-  #pragma omp parallel for schedule(static) reduction(+:sum)
+  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum)
  for (data_size_t i = 0; i < last - first; ++i) {
    sum += first[i];
  }
  if (num_data_ != sum) {
-    Log::Fatal("Sum of query counts differs from the length of #data");
Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", num_data_, sum);
  }
  num_queries_ = last - first;

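For context on the check above: the group/query counts passed to the Python package must sum to the number of rows in the data. A minimal sketch of the failure path, assuming pyarrow is installed and LightGBM is built with the Arrow support this PR series is adding; the table, column name, and group sizes below are illustrative, not taken from the diff:

import lightgbm as lgb
import numpy as np
import pyarrow as pa

# 1000 rows of dummy feature data.
table = pa.table({"a": np.random.default_rng(0).standard_normal(1000)})
# Group sizes sum to 750, not 1000, so constructing the dataset should fail.
bad_groups = pa.chunked_array([[300, 400, 50]], type=pa.int32())

try:
    lgb.Dataset(table, group=bad_groups).construct()
except lgb.basic.LightGBMError as err:
    print(err)  # expected: the "Sum of query counts ... differs from the length of #data" error above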
19 changes: 8 additions & 11 deletions tests/python_package_test/test_arrow.py
@@ -1,7 +1,6 @@
# coding: utf-8
import filecmp
-from pathlib import Path
-from typing import Any, Callable, Dict
+from typing import Any, Dict

import numpy as np
import pyarrow as pa
@@ -38,13 +37,13 @@ def generate_dummy_arrow_table() -> pa.Table:
    return pa.Table.from_arrays([col1, col2], names=["a", "b"])


-def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table:
+def generate_random_arrow_table(num_columns, num_datapoints, seed) -> pa.Table:
    columns = [generate_random_arrow_array(num_datapoints, seed + i) for i in range(num_columns)]
    names = [f"col_{i}" for i in range(num_columns)]
    return pa.Table.from_arrays(columns, names=names)


-def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray:
+def generate_random_arrow_array(num_datapoints, seed) -> pa.ChunkedArray:
    generator = np.random.default_rng(seed)
    data = generator.standard_normal(num_datapoints)

@@ -85,9 +84,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
        (lambda: generate_random_arrow_table(100, 10000, 43), {}),
    ],
)
-def test_dataset_construct_fuzzy(
-    tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
-):
+def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
    arrow_table = arrow_table_fn()

    arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
@@ -108,7 +105,7 @@ def test_dataset_construct_fields_fuzzy():
    arrow_table = generate_random_arrow_table(3, 1000, 42)
    arrow_labels = generate_random_arrow_array(1000, 42)
    arrow_weights = generate_random_arrow_array(1000, 42)
-    arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.uint8())
+    arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())

    arrow_dataset = lgb.Dataset(
        arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
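Aside on the pa.uint8() to pa.int32() change above (commentary, not part of the diff): group sizes of 300, 400, and 250 do not fit into an unsigned 8-bit integer, and pyarrow rejects out-of-range values when building the array, which is presumably why the type was widened. A quick illustration:

import pyarrow as pa

pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())  # fine: all sizes fit in int32
try:
    pa.chunked_array([[300, 400, 50], [250]], type=pa.uint8())
except pa.lib.ArrowInvalid:
    pass  # 300 and 400 are outside the uint8 range 0-255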
@@ -151,7 +148,7 @@ def test_dataset_construct_fields_fuzzy():
        pa.float64(),
    ],
)
-def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
+def test_dataset_construct_labels(array_type, label_data, arrow_type):
    data = generate_dummy_arrow_table()
    labels = array_type(label_data, type=arrow_type)
    dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
@@ -178,7 +175,7 @@ def test_dataset_construct_weights_none():
    [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
)
@pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
-def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
+def test_dataset_construct_weights(array_type, weight_data, arrow_type):
    data = generate_dummy_arrow_table()
    weights = array_type(weight_data, type=arrow_type)
    dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
@@ -207,7 +204,7 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type
        pa.uint64(),
    ],
)
-def test_dataset_construct_groups(array_type: Any, group_data: Any, arrow_type: Any):
+def test_dataset_construct_groups(array_type, group_data, arrow_type):
    data = generate_dummy_arrow_table()
    groups = array_type(group_data, type=arrow_type)
    dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params())
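For readers following the group-handling changes, a small usage sketch of the success path (again assuming pyarrow and a LightGBM build with this PR's Arrow support; names, sizes, and the printed values are illustrative and rely on LightGBM's behavior of storing groups as cumulative query boundaries):

import lightgbm as lgb
import numpy as np
import pyarrow as pa

table = pa.table({"a": np.random.default_rng(1).standard_normal(1000)})
groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())  # sizes sum to 1000

dataset = lgb.Dataset(table, group=groups).construct()
# The raw sizes [300, 400, 50, 250] are stored as boundaries, so this should
# print something like [0 300 700 750 1000].
print(dataset.get_field("group"))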
