Skip to content

Commit

Permalink
Drop duplicates in leaf files during SOAP. (#250)
Browse files Browse the repository at this point in the history
  • Loading branch information
delucchi-cmu authored Mar 6, 2024
1 parent ea60042 commit 9d0bfc4
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/hipscat_import/soap/map_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,19 @@ def _count_joins_for_object(source_data, source_pixel, object_pixel, soap_args):
prepared_data = pd.DataFrame(
data={
"object_id": joined_data.index.values,
"source_id": joined_data[soap_args.source_id_column],
"source_id": (
joined_data.index.values
if soap_args.source_object_id_column == soap_args.source_id_column
else joined_data[soap_args.source_id_column]
),
"Norder": np.full(rows_written, fill_value=object_pixel.order, dtype=np.uint8),
"Dir": np.full(rows_written, fill_value=object_pixel.dir, dtype=np.uint64),
"Npix": np.full(rows_written, fill_value=object_pixel.pixel, dtype=np.uint64),
"join_Norder": np.full(rows_written, fill_value=source_pixel.order, dtype=np.uint8),
"join_Dir": np.full(rows_written, fill_value=source_pixel.dir, dtype=np.uint64),
"join_Npix": np.full(rows_written, fill_value=source_pixel.pixel, dtype=np.uint64),
},
)
).drop_duplicates()

# Write to parquet file.
pixel_dir = get_pixel_cache_directory(soap_args.tmp_path, object_pixel)
Expand Down Expand Up @@ -94,7 +98,7 @@ def count_joins(soap_args: SoapArguments, source_pixel: HealpixPixel, object_pix
pixel_order=source_pixel.order,
pixel_number=source_pixel.pixel,
)
if soap_args.write_leaf_files:
if soap_args.write_leaf_files and soap_args.source_object_id_column != soap_args.source_id_column:
read_columns = [soap_args.source_object_id_column, soap_args.source_id_column]
else:
read_columns = [soap_args.source_object_id_column]
Expand Down
68 changes: 68 additions & 0 deletions tests/hipscat_import/soap/test_run_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,71 @@ def test_object_to_source_with_leaves(

metadata_filename = os.path.join(small_sky_soap_args.catalog_path, "catalog_info.json")
assert_text_file_matches(expected_lines, metadata_filename)


@pytest.mark.dask
def test_object_to_source_with_leaves_drop_duplicates(
dask_client, tmp_path, small_sky_object_catalog, small_sky_source_catalog, assert_text_file_matches
):
"""Test creating association between object and source catalogs."""
small_sky_soap_args = SoapArguments(
object_catalog_dir=small_sky_object_catalog,
object_id_column="id",
source_catalog_dir=small_sky_source_catalog,
source_object_id_column="object_id",
output_path=tmp_path,
overwrite=True,
progress_bar=False,
write_leaf_files=True,
source_id_column="object_id",
output_artifact_name="small_sky_object_to_source",
)
runner.run(small_sky_soap_args, dask_client)

## Check that the association data can be parsed as a valid association catalog.
catalog = AssociationCatalog.read_from_hipscat(small_sky_soap_args.catalog_path)
assert catalog.on_disk
assert catalog.catalog_path == small_sky_soap_args.catalog_path
assert len(catalog.get_join_pixels()) == 14
assert catalog.catalog_info.total_rows == 148
assert catalog.catalog_info.contains_leaf_files

parquet_file_name = os.path.join(small_sky_soap_args.catalog_path, "Norder=0", "Dir=0", "Npix=11.parquet")
assert os.path.exists(parquet_file_name), f"file not found [{parquet_file_name}]"

parquet_file = pq.ParquetFile(parquet_file_name)
assert parquet_file.metadata.num_row_groups == 14
assert parquet_file.metadata.num_rows == 148
assert parquet_file.metadata.num_columns == 8

exepcted_schema = pa.schema(
[
pa.field("object_id", pa.int64()),
pa.field("source_id", pa.int64()),
pa.field("Norder", pa.uint8()),
pa.field("Dir", pa.uint64()),
pa.field("Npix", pa.uint64()),
pa.field("join_Norder", pa.uint8()),
pa.field("join_Dir", pa.uint64()),
pa.field("join_Npix", pa.uint64()),
]
)
assert parquet_file.metadata.schema.to_arrow_schema().equals(exepcted_schema, check_metadata=False)

expected_lines = [
"{",
' "catalog_name": "small_sky_object_to_source",',
' "catalog_type": "association",',
' "total_rows": 148,',
r' "primary_catalog": ".*small_sky_object_catalog",',
' "primary_column": "id",',
' "primary_column_association": "object_id",',
r' "join_catalog": ".*small_sky_source_catalog",',
' "join_column": "object_id",',
' "join_column_association": "source_id",',
' "contains_leaf_files": true',
"}",
]

metadata_filename = os.path.join(small_sky_soap_args.catalog_path, "catalog_info.json")
assert_text_file_matches(expected_lines, metadata_filename)

0 comments on commit 9d0bfc4

Please sign in to comment.