diff --git a/src/io/metadata.cpp b/src/io/metadata.cpp index 248feebabd29..d94b0ed3f2f7 100644 --- a/src/io/metadata.cpp +++ b/src/io/metadata.cpp @@ -510,7 +510,7 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da template void Metadata::SetQueriesFromIterator(It first, It last) { std::lock_guard lock(mutex_); - // Clear weights on empty input + // Clear query boundaries on empty input if (last - first == 0) { query_boundaries_.clear(); num_queries_ = 0; @@ -518,12 +518,12 @@ void Metadata::SetQueriesFromIterator(It first, It last) { } data_size_t sum = 0; - #pragma omp parallel for schedule(static) reduction(+:sum) + #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum) for (data_size_t i = 0; i < last - first; ++i) { sum += first[i]; } if (num_data_ != sum) { - Log::Fatal("Sum of query counts differs from the length of #data"); + Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", num_data_, sum); } num_queries_ = last - first; diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index cd9fa4dcf678..dc482ecfdc2d 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -1,7 +1,6 @@ # coding: utf-8 import filecmp -from pathlib import Path -from typing import Any, Callable, Dict +from typing import Any, Dict import numpy as np import pyarrow as pa @@ -38,13 +37,13 @@ def generate_dummy_arrow_table() -> pa.Table: return pa.Table.from_arrays([col1, col2], names=["a", "b"]) -def generate_random_arrow_table(num_columns: int, num_datapoints: int, seed: int) -> pa.Table: +def generate_random_arrow_table(num_columns, num_datapoints, seed) -> pa.Table: columns = [generate_random_arrow_array(num_datapoints, seed + i) for i in range(num_columns)] names = [f"col_{i}" for i in range(num_columns)] return pa.Table.from_arrays(columns, names=names) -def generate_random_arrow_array(num_datapoints: int, seed: int) -> pa.ChunkedArray: +def generate_random_arrow_array(num_datapoints, seed) -> pa.ChunkedArray: generator = np.random.default_rng(seed) data = generator.standard_normal(num_datapoints) @@ -85,9 +84,7 @@ def dummy_dataset_params() -> Dict[str, Any]: (lambda: generate_random_arrow_table(100, 10000, 43), {}), ], ) -def test_dataset_construct_fuzzy( - tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any] -): +def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params): arrow_table = arrow_table_fn() arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params) @@ -108,7 +105,7 @@ def test_dataset_construct_fields_fuzzy(): arrow_table = generate_random_arrow_table(3, 1000, 42) arrow_labels = generate_random_arrow_array(1000, 42) arrow_weights = generate_random_arrow_array(1000, 42) - arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.uint8()) + arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32()) arrow_dataset = lgb.Dataset( arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups @@ -151,7 +148,7 @@ def test_dataset_construct_fields_fuzzy(): pa.float64(), ], ) -def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any): +def test_dataset_construct_labels(array_type, label_data, arrow_type): data = generate_dummy_arrow_table() labels = array_type(label_data, type=arrow_type) dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params()) @@ -178,7 +175,7 @@ def test_dataset_construct_weights_none(): [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])], ) @pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()]) -def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any): +def test_dataset_construct_weights(array_type, weight_data, arrow_type): data = generate_dummy_arrow_table() weights = array_type(weight_data, type=arrow_type) dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params()) @@ -207,7 +204,7 @@ def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type pa.uint64(), ], ) -def test_dataset_construct_groups(array_type: Any, group_data: Any, arrow_type: Any): +def test_dataset_construct_groups(array_type, group_data, arrow_type): data = generate_dummy_arrow_table() groups = array_type(group_data, type=arrow_type) dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params())