Be safer around none values in metadata statistics. (#460)
delucchi-cmu authored Feb 28, 2025
1 parent af79a37 commit 757426f
Showing 2 changed files with 79 additions and 16 deletions.
56 changes: 40 additions & 16 deletions src/hats/io/parquet_metadata.py
@@ -192,6 +192,28 @@ def read_row_group_fragments(metadata_file: str):
yield from frag.row_groups


def _nonemin(value1, value2):
"""Similar to numpy's nanmin, but excludes `None` values.
NB: If both values are `None`, this will still return `None`"""
if value1 is None:
return value2
if value2 is None:
return value1
return min(value1, value2)


def _nonemax(value1, value2):
"""Similar to numpy's nanmax, but excludes `None` values.
NB: If both values are `None`, this will still return `None`"""
if value1 is None:
return value2
if value2 is None:
return value1
return max(value1, value2)


def aggregate_column_statistics(
metadata_file: str | Path | UPath,
exclude_hats_columns: bool = True,
@@ -229,12 +251,18 @@ def aggregate_column_statistics(
if (len(include_columns) == 0 or name in include_columns)
and not (len(exclude_columns) > 0 and name in exclude_columns)
]
if not good_column_indexes:
return pd.DataFrame()
column_names = [column_names[i] for i in good_column_indexes]
extrema = [
(
first_row_group.column(col).statistics.min,
first_row_group.column(col).statistics.max,
first_row_group.column(col).statistics.null_count,
(None, None, 0)
if first_row_group.column(col).statistics is None
else (
first_row_group.column(col).statistics.min,
first_row_group.column(col).statistics.max,
first_row_group.column(col).statistics.null_count,
)
)
for col in good_column_indexes
]
@@ -243,25 +271,21 @@
row_group = total_metadata.row_group(row_group_index)
row_stats = [
(
row_group.column(col).statistics.min,
row_group.column(col).statistics.max,
row_group.column(col).statistics.null_count,
(None, None, 0)
if row_group.column(col).statistics is None
else (
row_group.column(col).statistics.min,
row_group.column(col).statistics.max,
row_group.column(col).statistics.null_count,
)
)
for col in good_column_indexes
]
## This is annoying, but avoids extra copies or None comparisons.
extrema = [
(
(
min(extrema[col][0], row_stats[col][0])
if row_stats[col][0] is not None
else extrema[col][0]
),
(
max(extrema[col][1], row_stats[col][1])
if row_stats[col][1] is not None
else extrema[col][1]
),
_nonemin(extrema[col][0], row_stats[col][0]),
_nonemax(extrema[col][1], row_stats[col][1]),
extrema[col][2] + row_stats[col][2],
)
for col in range(0, len(good_column_indexes))
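For reference, a minimal sketch of the None-handling semantics the new helpers implement (names taken from the diff above; the helpers are private to hats.io.parquet_metadata, and the values here are illustrative):

from hats.io.parquet_metadata import _nonemax, _nonemin

# A None on either side is skipped rather than propagated, so a row
# group with missing statistics cannot poison the running extrema.
assert _nonemin(3.0, None) == 3.0
assert _nonemin(None, 5.0) == 5.0
assert _nonemin(3.0, 5.0) == 3.0
assert _nonemax(None, 7.0) == 7.0

# Only when BOTH inputs are None does None survive the fold.
assert _nonemin(None, None) is None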
39 changes: 39 additions & 0 deletions tests/hats/io/test_parquet_metadata.py
@@ -4,6 +4,7 @@

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from hats.io import file_io, paths
@@ -157,6 +158,44 @@ def test_aggregate_column_statistics(small_sky_order1_dir):
result_frame = aggregate_column_statistics(partition_info_file, include_columns=["ra", "dec"])
assert len(result_frame) == 2

result_frame = aggregate_column_statistics(partition_info_file, include_columns=["does", "not", "exist"])
assert len(result_frame) == 0


def test_aggregate_column_statistics_with_nulls(tmp_path):
file_io.make_directory(tmp_path / "dataset")

metadata_filename = tmp_path / "dataset" / "dataframe_01.parquet"
table_with_schema = pa.Table.from_arrays([[-1.0], [1.0]], names=["data", "Npix"])
pq.write_table(table_with_schema, metadata_filename)

icky_table = pa.Table.from_arrays([[2.0, None], [None, 6.0]], schema=table_with_schema.schema)
metadata_filename = tmp_path / "dataset" / "dataframe_02.parquet"
pq.write_table(icky_table, metadata_filename)

icky_table = pa.Table.from_arrays([[None], [None]], schema=table_with_schema.schema)
metadata_filename = tmp_path / "dataset" / "dataframe_00.parquet"
pq.write_table(icky_table, metadata_filename)

icky_table = pa.Table.from_arrays([[None, None], [None, None]], schema=table_with_schema.schema)
metadata_filename = tmp_path / "dataset" / "dataframe_03.parquet"
pq.write_table(icky_table, metadata_filename)

assert write_parquet_metadata(tmp_path, order_by_healpix=False) == 6

result_frame = aggregate_column_statistics(tmp_path / "dataset" / "_metadata", exclude_hats_columns=False)
assert len(result_frame) == 2

data_stats = result_frame.loc["data"]
assert data_stats["min_value"] == -1
assert data_stats["max_value"] == 2
assert data_stats["null_count"] == 4

data_stats = result_frame.loc["Npix"]
assert data_stats["min_value"] == 1
assert data_stats["max_value"] == 6
assert data_stats["null_count"] == 4


def test_row_group_stats(small_sky_dir):
partition_info_file = paths.get_parquet_metadata_pointer(small_sky_dir)
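As background for the guard in the production change: pyarrow can legitimately hand back None for a column chunk's statistics, for example when a file is written with statistics disabled. A hedged sketch using standard pyarrow calls (the file name is hypothetical):

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.Table.from_arrays([[1.0, None]], names=["data"])

# Writing with statistics disabled yields row groups whose column
# metadata carries no Statistics object at all.
pq.write_table(table, "no_stats.parquet", write_statistics=False)

meta = pq.read_metadata("no_stats.parquet")
# statistics is None here -- the case aggregate_column_statistics now guards.
assert meta.row_group(0).column(0).statistics is None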
