Be safer around none values in metadata statistics. #460

Merged · 1 commit · Feb 28, 2025

56 changes: 40 additions & 16 deletions src/hats/io/parquet_metadata.py
@@ -192,6 +192,28 @@ def read_row_group_fragments(metadata_file: str):
         yield from frag.row_groups
 
 
+def _nonemin(value1, value2):
+    """Similar to numpy's nanmin, but excludes `None` values.
+
+    NB: If both values are `None`, this will still return `None`"""
+    if value1 is None:
+        return value2
+    if value2 is None:
+        return value1
+    return min(value1, value2)
+
+
+def _nonemax(value1, value2):
+    """Similar to numpy's nanmax, but excludes `None` values.
+
+    NB: If both values are `None`, this will still return `None`"""
+    if value1 is None:
+        return value2
+    if value2 is None:
+        return value1
+    return max(value1, value2)
+
+
 def aggregate_column_statistics(
     metadata_file: str | Path | UPath,
     exclude_hats_columns: bool = True,
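As an aside (not part of the diff), the helpers' behavior is easy to pin down with a few spot checks, including the both-`None` case the docstrings call out:

    _nonemin(None, 5)     # -> 5
    _nonemin(3, None)     # -> 3
    _nonemin(3, 5)        # -> 3
    _nonemin(None, None)  # -> None (both inputs missing)
    _nonemax(None, 7)     # -> 7

Unlike numpy's nanmin/nanmax, these accept any comparable values (parquet statistics can be strings or timestamps, not just floats) and emit no warning when everything is missing.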
@@ -229,12 +251,18 @@ def aggregate_column_statistics(
         if (len(include_columns) == 0 or name in include_columns)
         and not (len(exclude_columns) > 0 and name in exclude_columns)
     ]
+    if not good_column_indexes:
+        return pd.DataFrame()
     column_names = [column_names[i] for i in good_column_indexes]
     extrema = [
         (
-            first_row_group.column(col).statistics.min,
-            first_row_group.column(col).statistics.max,
-            first_row_group.column(col).statistics.null_count,
+            (None, None, 0)
+            if first_row_group.column(col).statistics is None
+            else (
+                first_row_group.column(col).statistics.min,
+                first_row_group.column(col).statistics.max,
+                first_row_group.column(col).statistics.null_count,
+            )
         )
         for col in good_column_indexes
     ]
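The new early return also changes the caller-visible contract when the column filters match nothing: the caller gets an empty DataFrame instead of an error further down. A hypothetical call (path invented for illustration):

    stats = aggregate_column_statistics(
        "some_catalog/dataset/_metadata", include_columns=["no_such_column"]
    )
    assert stats.empty

This is exactly what the new test below asserts with include_columns=["does", "not", "exist"].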
@@ -243,25 +271,21 @@ def aggregate_column_statistics(
         row_group = total_metadata.row_group(row_group_index)
         row_stats = [
             (
-                row_group.column(col).statistics.min,
-                row_group.column(col).statistics.max,
-                row_group.column(col).statistics.null_count,
+                (None, None, 0)
+                if row_group.column(col).statistics is None
+                else (
+                    row_group.column(col).statistics.min,
+                    row_group.column(col).statistics.max,
+                    row_group.column(col).statistics.null_count,
+                )
             )
             for col in good_column_indexes
         ]
         ## This is annoying, but avoids extra copies, or none comparison.
         extrema = [
             (
-                (
-                    min(extrema[col][0], row_stats[col][0])
-                    if row_stats[col][0] is not None
-                    else extrema[col][0]
-                ),
-                (
-                    max(extrema[col][1], row_stats[col][1])
-                    if row_stats[col][1] is not None
-                    else extrema[col][1]
-                ),
+                (_nonemin(extrema[col][0], row_stats[col][0])),
+                (_nonemax(extrema[col][1], row_stats[col][1])),
                 extrema[col][2] + row_stats[col][2],
             )
             for col in range(0, len(good_column_indexes))
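Why can statistics be None at all? A parquet writer may omit column-chunk statistics entirely, and pyarrow then surfaces None rather than a Statistics object. A standalone sketch (not from this PR) reproducing the condition the guard defends against:

    import pyarrow as pa
    import pyarrow.parquet as pq

    table = pa.table({"data": [1.0, 2.0]})
    # Opt out of statistics; the column chunk metadata then carries none at all.
    pq.write_table(table, "/tmp/no_stats.parquet", write_statistics=False)

    metadata = pq.read_metadata("/tmp/no_stats.parquet")
    print(metadata.row_group(0).column(0).statistics)  # None

Treating a missing Statistics object as (None, None, 0) lets such row groups pass through the aggregation without contributing extrema or null counts.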
39 changes: 39 additions & 0 deletions tests/hats/io/test_parquet_metadata.py
@@ -4,6 +4,7 @@
 
 import pandas as pd
 import pyarrow as pa
+import pyarrow.parquet as pq
 import pytest
 
 from hats.io import file_io, paths
@@ -157,6 +158,44 @@ def test_aggregate_column_statistics(small_sky_order1_dir):
     result_frame = aggregate_column_statistics(partition_info_file, include_columns=["ra", "dec"])
     assert len(result_frame) == 2
 
+    result_frame = aggregate_column_statistics(partition_info_file, include_columns=["does", "not", "exist"])
+    assert len(result_frame) == 0
+
+
+def test_aggregate_column_statistics_with_nulls(tmp_path):
+    file_io.make_directory(tmp_path / "dataset")
+
+    metadata_filename = tmp_path / "dataset" / "dataframe_01.parquet"
+    table_with_schema = pa.Table.from_arrays([[-1.0], [1.0]], names=["data", "Npix"])
+    pq.write_table(table_with_schema, metadata_filename)
+
+    icky_table = pa.Table.from_arrays([[2.0, None], [None, 6.0]], schema=table_with_schema.schema)
+    metadata_filename = tmp_path / "dataset" / "dataframe_02.parquet"
+    pq.write_table(icky_table, metadata_filename)
+
+    icky_table = pa.Table.from_arrays([[None], [None]], schema=table_with_schema.schema)
+    metadata_filename = tmp_path / "dataset" / "dataframe_00.parquet"
+    pq.write_table(icky_table, metadata_filename)
+
+    icky_table = pa.Table.from_arrays([[None, None], [None, None]], schema=table_with_schema.schema)
+    metadata_filename = tmp_path / "dataset" / "dataframe_03.parquet"
+    pq.write_table(icky_table, metadata_filename)
+
+    assert write_parquet_metadata(tmp_path, order_by_healpix=False) == 6
+
+    result_frame = aggregate_column_statistics(tmp_path / "dataset" / "_metadata", exclude_hats_columns=False)
+    assert len(result_frame) == 2
+
+    data_stats = result_frame.loc["data"]
+    assert data_stats["min_value"] == -1
+    assert data_stats["max_value"] == 2
+    assert data_stats["null_count"] == 4
+
+    data_stats = result_frame.loc["Npix"]
+    assert data_stats["min_value"] == 1
+    assert data_stats["max_value"] == 6
+    assert data_stats["null_count"] == 4
+
 
 def test_row_group_stats(small_sky_dir):
     partition_info_file = paths.get_parquet_metadata_pointer(small_sky_dir)
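As a sanity check on the expected numbers, the per-file extrema for the "data" column reduce to exactly the asserted values. An illustrative recap (not part of the PR; it assumes, as the asserts above imply, that all-null chunks still record their null counts):

    from functools import reduce

    # Standalone mirrors of the private helpers in parquet_metadata.py.
    def nonemin(a, b):
        return b if a is None else a if b is None else min(a, b)

    def nonemax(a, b):
        return b if a is None else a if b is None else max(a, b)

    # Per-file (min, max, null_count) for the "data" column written above:
    # dataframe_00 [None], _01 [-1.0], _02 [2.0, None], _03 [None, None].
    per_file = [(None, None, 1), (-1.0, -1.0, 0), (2.0, 2.0, 1), (None, None, 2)]

    mins, maxes, nulls = zip(*per_file)
    assert reduce(nonemin, mins) == -1.0  # matches min_value == -1
    assert reduce(nonemax, maxes) == 2.0  # matches max_value == 2
    assert sum(nulls) == 4                # matches null_count == 4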