diff --git a/src/hats/io/parquet_metadata.py b/src/hats/io/parquet_metadata.py
index 3b1f4ece..60e88c2b 100644
--- a/src/hats/io/parquet_metadata.py
+++ b/src/hats/io/parquet_metadata.py
@@ -192,6 +192,28 @@ def read_row_group_fragments(metadata_file: str):
         yield from frag.row_groups
 
 
+def _nonemin(value1, value2):
+    """Similar to numpy's nanmin, but excludes `None` values.
+
+    NB: If both values are `None`, this will still return `None`"""
+    if value1 is None:
+        return value2
+    if value2 is None:
+        return value1
+    return min(value1, value2)
+
+
+def _nonemax(value1, value2):
+    """Similar to numpy's nanmax, but excludes `None` values.
+
+    NB: If both values are `None`, this will still return `None`"""
+    if value1 is None:
+        return value2
+    if value2 is None:
+        return value1
+    return max(value1, value2)
+
+
 def aggregate_column_statistics(
     metadata_file: str | Path | UPath,
     exclude_hats_columns: bool = True,
@@ -229,12 +251,18 @@ def aggregate_column_statistics(
         if (len(include_columns) == 0 or name in include_columns)
         and not (len(exclude_columns) > 0 and name in exclude_columns)
     ]
+    if not good_column_indexes:
+        return pd.DataFrame()
     column_names = [column_names[i] for i in good_column_indexes]
     extrema = [
         (
-            first_row_group.column(col).statistics.min,
-            first_row_group.column(col).statistics.max,
-            first_row_group.column(col).statistics.null_count,
+            (None, None, 0)
+            if first_row_group.column(col).statistics is None
+            else (
+                first_row_group.column(col).statistics.min,
+                first_row_group.column(col).statistics.max,
+                first_row_group.column(col).statistics.null_count,
+            )
         )
         for col in good_column_indexes
     ]
@@ -243,25 +271,21 @@ def aggregate_column_statistics(
         row_group = total_metadata.row_group(row_group_index)
         row_stats = [
             (
-                row_group.column(col).statistics.min,
-                row_group.column(col).statistics.max,
-                row_group.column(col).statistics.null_count,
+                (None, None, 0)
+                if row_group.column(col).statistics is None
+                else (
+                    row_group.column(col).statistics.min,
+                    row_group.column(col).statistics.max,
+                    row_group.column(col).statistics.null_count,
+                )
             )
             for col in good_column_indexes
         ]
         ## This is annoying, but avoids extra copies, or none comparison.
         extrema = [
             (
-                (
-                    min(extrema[col][0], row_stats[col][0])
-                    if row_stats[col][0] is not None
-                    else extrema[col][0]
-                ),
-                (
-                    max(extrema[col][1], row_stats[col][1])
-                    if row_stats[col][1] is not None
-                    else extrema[col][1]
-                ),
+                (_nonemin(extrema[col][0], row_stats[col][0])),
+                (_nonemax(extrema[col][1], row_stats[col][1])),
                 extrema[col][2] + row_stats[col][2],
             )
             for col in range(0, len(good_column_indexes))
diff --git a/tests/hats/io/test_parquet_metadata.py b/tests/hats/io/test_parquet_metadata.py
index 89a37693..03d3be53 100644
--- a/tests/hats/io/test_parquet_metadata.py
+++ b/tests/hats/io/test_parquet_metadata.py
@@ -4,6 +4,7 @@
 
 import pandas as pd
 import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 
 from hats.io import file_io, paths
@@ -157,6 +158,44 @@ def test_aggregate_column_statistics(small_sky_order1_dir):
     result_frame = aggregate_column_statistics(partition_info_file, include_columns=["ra", "dec"])
     assert len(result_frame) == 2
 
+    result_frame = aggregate_column_statistics(partition_info_file, include_columns=["does", "not", "exist"])
+    assert len(result_frame) == 0
+
+
+def test_aggregate_column_statistics_with_nulls(tmp_path):
+    file_io.make_directory(tmp_path / "dataset")
+
+    metadata_filename = tmp_path / "dataset" / "dataframe_01.parquet"
+    table_with_schema = pa.Table.from_arrays([[-1.0], [1.0]], names=["data", "Npix"])
+    pq.write_table(table_with_schema, metadata_filename)
+
+    icky_table = pa.Table.from_arrays([[2.0, None], [None, 6.0]], schema=table_with_schema.schema)
+    metadata_filename = tmp_path / "dataset" / "dataframe_02.parquet"
+    pq.write_table(icky_table, metadata_filename)
+
+    icky_table = pa.Table.from_arrays([[None], [None]], schema=table_with_schema.schema)
+    metadata_filename = tmp_path / "dataset" / "dataframe_00.parquet"
+    pq.write_table(icky_table, metadata_filename)
+
+    icky_table = pa.Table.from_arrays([[None, None], [None, None]], schema=table_with_schema.schema)
+    metadata_filename = tmp_path / "dataset" / "dataframe_03.parquet"
+    pq.write_table(icky_table, metadata_filename)
+
+    assert write_parquet_metadata(tmp_path, order_by_healpix=False) == 6
+
+    result_frame = aggregate_column_statistics(tmp_path / "dataset" / "_metadata", exclude_hats_columns=False)
+    assert len(result_frame) == 2
+
+    data_stats = result_frame.loc["data"]
+    assert data_stats["min_value"] == -1
+    assert data_stats["max_value"] == 2
+    assert data_stats["null_count"] == 4
+
+    data_stats = result_frame.loc["Npix"]
+    assert data_stats["min_value"] == 1
+    assert data_stats["max_value"] == 6
+    assert data_stats["null_count"] == 4
+
 
 def test_row_group_stats(small_sky_dir):
     partition_info_file = paths.get_parquet_metadata_pointer(small_sky_dir)
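Reviewer note: the sketch below is not part of the patch; it only illustrates the reduction step in aggregate_column_statistics, assuming the patched hats.io.parquet_metadata is importable and that the private helpers keep the names _nonemin/_nonemax. The previous inline conditionals only guarded the incoming row-group value, so a statistics-free first row group could leave the running minimum as None and make min(None, x) raise TypeError; the helpers close that gap.

# Illustrative sketch only (not part of the patch) -- mirrors the per-row-group
# reduction in aggregate_column_statistics.
from hats.io.parquet_metadata import _nonemax, _nonemin

# A row group written from an all-null column carries no parquet statistics,
# so the patch substitutes (None, None, 0) for its (min, max, null_count).
running = (None, None, 0)
incoming = (2.0, 6.0, 1)  # a later row group that does carry statistics

running = (
    _nonemin(running[0], incoming[0]),  # None is skipped -> 2.0
    _nonemax(running[1], incoming[1]),  # None is skipped -> 6.0
    running[2] + incoming[2],           # null counts always accumulate -> 1
)
print(running)  # (2.0, 6.0, 1); stays (None, None, 0) only if no row group has statistics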