Merge branch 'main' into delucchi/discovery
delucchi-cmu authored Mar 3, 2025
2 parents dd03342 + 98205be commit 82c152c
Showing 24 changed files with 186 additions and 88 deletions.
1 change: 0 additions & 1 deletion src/hats_import/margin_cache/margin_cache.py
@@ -55,7 +55,6 @@ def generate_margin_cache(args, client):
output_path=args.catalog_path,
partition_order=pix.order,
partition_pixel=pix.pixel,
original_catalog_metadata=original_catalog_metadata,
delete_intermediate_parquet_files=args.delete_intermediate_parquet_files,
)
)
85 changes: 45 additions & 40 deletions src/hats_import/margin_cache/margin_cache_map_reduce.py
@@ -3,14 +3,16 @@
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from hats.io import file_io, paths
from hats.pixel_math.healpix_pixel import HealpixPixel
from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN

from hats_import.margin_cache.margin_cache_resume_plan import MarginCachePlan
from hats_import.pipeline_resume_plan import get_pixel_cache_directory, print_task_failure


# pylint: disable=too-many-arguments, unused-argument
# pylint: disable=too-many-arguments,disable=too-many-locals
def map_pixel_shards(
partition_file,
mapping_key,
@@ -29,8 +31,10 @@ def map_pixel_shards(
raise NotImplementedError("Fine filtering temporarily removed.")

schema = file_io.read_parquet_metadata(original_catalog_metadata).schema.to_arrow_schema()
data = file_io.read_parquet_file_to_pandas(partition_file, schema=schema)
source_pixel = HealpixPixel(data["Norder"].iloc[0], data["Npix"].iloc[0])
data = pq.read_table(
partition_file.path, filesystem=partition_file.fs, schema=schema
).combine_chunks()
source_pixel = HealpixPixel(data["Norder"][0].as_py(), data["Npix"][0].as_py())

# Constrain the possible margin pairs, first by only those `margin_order` pixels
# that **can** be contained in source pixel, then by `margin_order` pixels for rows
@@ -45,8 +49,8 @@

margin_pixel_list = hp.radec2pix(
margin_order,
data[ra_column].values,
data[dec_column].values,
data[ra_column].to_numpy(),
data[dec_column].to_numpy(),
)
margin_pixel_filter = pd.DataFrame(
{"margin_pixel": margin_pixel_list, "filter_value": np.arange(0, len(margin_pixel_list))}
@@ -57,10 +61,10 @@
# and write out shard file.
num_rows = 0
for partition_key, data_filter in margin_pixel_filter.groupby(["partition_order", "partition_pixel"]):
data_filter = np.unique(data_filter["filter_value"]).tolist()
data_filter = np.unique(data_filter["filter_value"])
filtered_data = data.take(data_filter)
pixel = HealpixPixel(partition_key[0], partition_key[1])

filtered_data = data.iloc[data_filter]
num_rows += _to_pixel_shard(
filtered_data=filtered_data,
pixel=pixel,
@@ -102,29 +106,38 @@ def _to_pixel_shard(

shard_path = paths.pixel_catalog_file(partition_dir, source_pixel)

rename_columns = {
paths.PARTITION_ORDER: paths.MARGIN_ORDER,
paths.PARTITION_DIR: paths.MARGIN_DIR,
paths.PARTITION_PIXEL: paths.MARGIN_PIXEL,
}
margin_data = _rename_original_pixel_columns(margin_data)
margin_data = _append_margin_pixel_columns(margin_data, pixel)
margin_data = margin_data.sort_by(SPATIAL_INDEX_COLUMN)

margin_data = margin_data.rename(columns=rename_columns)
pq.write_table(margin_data, shard_path.path, filesystem=shard_path.fs)
return num_rows

margin_data[paths.PARTITION_ORDER] = pixel.order
margin_data[paths.PARTITION_DIR] = pixel.dir
margin_data[paths.PARTITION_PIXEL] = pixel.pixel

margin_data = margin_data.astype(
{
paths.PARTITION_ORDER: np.uint8,
paths.PARTITION_DIR: np.uint64,
paths.PARTITION_PIXEL: np.uint64,
}
)
margin_data = margin_data.sort_index()
def _rename_original_pixel_columns(margin_data):
"""Rename source pixel columns to include margin prefix"""
rename_columns = {
paths.PARTITION_ORDER: paths.MARGIN_ORDER,
paths.PARTITION_DIR: paths.MARGIN_DIR,
paths.PARTITION_PIXEL: paths.MARGIN_PIXEL,
}
return margin_data.rename_columns(rename_columns)

margin_data.to_parquet(shard_path.path, filesystem=shard_path.fs)
return num_rows

def _append_margin_pixel_columns(margin_data, pixel):
"""Append margin pixel columns to the shard table"""
num_rows = len(margin_data)
order_values = pa.repeat(pa.scalar(pixel.order, type=pa.uint8()), num_rows)
dir_values = pa.repeat(pa.scalar(pixel.dir, type=pa.uint64()), num_rows)
pixel_values = pa.repeat(pa.scalar(pixel.pixel, type=pa.uint64()), num_rows)
pixel_columns = {
paths.PARTITION_ORDER: order_values,
paths.PARTITION_DIR: dir_values,
paths.PARTITION_PIXEL: pixel_values,
}
for col_name, col_values in pixel_columns.items():
margin_data = margin_data.append_column(col_name, col_values)
return margin_data


def reduce_margin_shards(
@@ -133,33 +146,25 @@
output_path,
partition_order,
partition_pixel,
original_catalog_metadata,
delete_intermediate_parquet_files,
):
"""Reduce all partition pixel directories into a single file"""
try:
healpix_pixel = HealpixPixel(partition_order, partition_pixel)
shard_dir = get_pixel_cache_directory(intermediate_directory, healpix_pixel)
if file_io.does_file_or_directory_exist(shard_dir):
schema = file_io.read_parquet_metadata(original_catalog_metadata).schema.to_arrow_schema()

schema = (
schema.append(pa.field(paths.MARGIN_ORDER, pa.uint8()))
.append(pa.field(paths.MARGIN_DIR, pa.uint64()))
.append(pa.field(paths.MARGIN_PIXEL, pa.uint64()))
)
data = ds.dataset(shard_dir, format="parquet", schema=schema)
full_df = data.to_table().to_pandas()
if file_io.does_file_or_directory_exist(shard_dir):
margin_table = ds.dataset(shard_dir.path, filesystem=shard_dir.fs, format="parquet").to_table()

if len(full_df):
if len(margin_table):
margin_cache_dir = paths.pixel_directory(output_path, partition_order, partition_pixel)
file_io.make_directory(margin_cache_dir, exist_ok=True)

margin_cache_file_path = paths.pixel_catalog_file(output_path, healpix_pixel)

full_df.to_parquet(
margin_cache_file_path.path, schema=schema, filesystem=margin_cache_file_path.fs
pq.write_table(
margin_table, margin_cache_file_path.path, filesystem=margin_cache_file_path.fs
)

if delete_intermediate_parquet_files:
file_io.remove_directory(shard_dir)

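For context, the core of this refactor replaces pandas DataFrame operations with PyArrow Table operations in the margin map/reduce steps. Below is a minimal standalone sketch of that pattern; the column names and values are illustrative only, standing in for the repository's paths.* constants and hats file_io helpers.

    import pyarrow as pa
    import pyarrow.parquet as pq

    # Toy table standing in for one margin shard; "_healpix_29" mirrors the
    # spatial index column the real code sorts by.
    table = pa.table(
        {
            "_healpix_29": pa.array([3, 1, 2], type=pa.int64()),
            "Norder": pa.array([2, 2, 2], type=pa.uint8()),
        }
    )

    # Rename an original pixel column, as _rename_original_pixel_columns does.
    # (Dict-based renaming assumes a reasonably recent pyarrow; older versions
    # of rename_columns expect a full list of new names.)
    table = table.rename_columns({"Norder": "margin_Norder"})

    # Append constant-valued columns for the owning pixel, as
    # _append_margin_pixel_columns does with pa.repeat.
    num_rows = table.num_rows
    table = table.append_column("Norder", pa.repeat(pa.scalar(1, type=pa.uint8()), num_rows))
    table = table.append_column("Npix", pa.repeat(pa.scalar(44, type=pa.uint64()), num_rows))

    # Sort by the spatial index and write the shard, as _to_pixel_shard does.
    table = table.sort_by("_healpix_29")
    pq.write_table(table, "margin_shard.parquet")
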
34 changes: 32 additions & 2 deletions tests/data/generate_data.ipynb
@@ -25,6 +25,7 @@
"metadata": {},
"outputs": [],
"source": [
"import lsdb\n",
"import hats_import.pipeline as runner\n",
"from hats_import.catalog.arguments import ImportArguments\n",
"import tempfile\n",
@@ -124,6 +125,35 @@
"tmp_path.cleanup()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nested catalog: small_sky_nested_catalog"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"small_sky_object = lsdb.read_hats(\"small_sky_object_catalog\")\n",
"small_sky_source = lsdb.read_hats(\"small_sky_source_catalog\")\n",
"small_sky_nested = small_sky_object.join_nested(\n",
" small_sky_source, left_on=\"id\", right_on=\"object_id\", nested_column_name=\"lc\"\n",
")\n",
"small_sky_nested = small_sky_nested.map_partitions(\n",
" lambda df, p: df.assign(Norder=p.order, Npix=p.pixel, Dir=p.pixel // 10000), include_pixel=True\n",
")\n",
"lsdb.io.to_hats(\n",
" small_sky_nested,\n",
" base_catalog_path=\"small_sky_nested_catalog\",\n",
" catalog_name=\"small_sky_nested_catalog\",\n",
" histogram_order=5,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -369,7 +399,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "env",
"display_name": "lsdb",
"language": "python",
"name": "python3"
},
Expand All @@ -383,7 +413,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.11"
}
},
"nbformat": 4,
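
The new notebook cell above builds small_sky_nested_catalog with lsdb. A quick, hypothetical spot-check of the generated catalog, not part of the commit and assuming lsdb's read_hats/compute API as used elsewhere in the notebook:

    import lsdb

    # Load the freshly written nested catalog and materialize it.
    nested = lsdb.read_hats("small_sky_nested_catalog")
    df = nested.compute()

    # Expect object-level columns plus the nested "lc" column holding the
    # joined small_sky_source rows for each object.
    print(df.columns)
    print(df["lc"].iloc[0])
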
14 changes: 14 additions & 0 deletions tests/data/small_sky_nested_catalog/partition_info.csv
@@ -0,0 +1,14 @@
Norder,Npix
2,176
2,177
2,178
2,179
2,180
2,181
2,182
2,183
2,184
2,185
2,186
2,187
1,47
14 changes: 14 additions & 0 deletions tests/data/small_sky_nested_catalog/properties
@@ -0,0 +1,14 @@
#HATS catalog
obs_collection=small_sky_nested_catalog
dataproduct_type=object
hats_nrows=131
hats_col_ra=ra
hats_col_dec=dec
hats_max_rows=1000000
hats_order=0
moc_sky_fraction=0.08333
hats_builder=hats-import v0.4.2.dev1+g57aaa9d
hats_creation_date=2024-11-07T15\:20UTC
hats_estsize=113
hats_release_date=2024-09-18
hats_version=v0.1
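
The properties file above is a flat key=value listing with Java-properties-style escaping (note the "\:" sequences in hats_creation_date). A minimal, hypothetical reader for files like this one, ignoring multi-line values and other escape sequences:

    def read_hats_properties(path):
        """Parse a simple key=value properties file, unescaping '\\:' sequences."""
        props = {}
        with open(path, encoding="utf-8") as handle:
            for line in handle:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue  # skip blanks and comment lines such as "#HATS catalog"
                key, _, value = line.partition("=")
                props[key.strip()] = value.strip().replace("\\:", ":")
        return props

    # Example: prints "131" for the file shown above.
    print(read_hats_properties("tests/data/small_sky_nested_catalog/properties")["hats_nrows"])
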
19 changes: 14 additions & 5 deletions tests/hats_import/conftest.py
@@ -85,6 +85,11 @@ def small_sky_source_catalog(test_data_dir):
return test_data_dir / "small_sky_source_catalog"


@pytest.fixture
def small_sky_nested_catalog(test_data_dir):
return test_data_dir / "small_sky_nested_catalog"


@pytest.fixture
def blank_data_dir(test_data_dir):
return test_data_dir / "blank"
@@ -175,16 +180,18 @@ def basic_data_shard_df():
ras = np.arange(0.0, 360.0)
dec = np.full(360, 0.0)
norder = np.full(360, 1)
_dir = np.full(360, 0)
npix = np.full(360, 0)
spatial_indexes = pixel_math.compute_spatial_index(ras, dec)

test_df = pd.DataFrame(
data=zip(spatial_indexes, ras, dec, norder, npix),
data=zip(spatial_indexes, ras, dec, norder, _dir, npix),
columns=[
"_healpix_29",
"weird_ra",
"weird_dec",
"Norder",
"Dir",
"Npix",
],
)
@@ -196,16 +203,18 @@ def polar_data_shard_df():
ras = np.arange(0.0, 360.0)
dec = np.full(360, 89.9)
norder = np.full(360, 2)
_dir = np.full(360, 0)
npix = np.full(360, 0)
spatial_indexes = pixel_math.compute_spatial_index(ras, dec)

test_df = pd.DataFrame(
data=zip(spatial_indexes, ras, dec, norder, npix),
data=zip(spatial_indexes, ras, dec, norder, _dir, npix),
columns=[
"_healpix_29",
"weird_ra",
"weird_dec",
"Norder",
"Dir",
"Npix",
],
)
@@ -243,9 +252,9 @@ def assert_text_file_matches(expected_lines, file_name):
contents
), f"files not the same length ({len(contents)} vs {len(expected_lines)})"
for i, expected in enumerate(expected_lines):
assert re.match(expected, contents[i]), (
f"files do not match at line {i+1} " f"(actual: [{contents[i]}] vs expected: [{expected}])"
)
assert re.match(
expected, contents[i]
), f"files do not match at line {i + 1} (actual: [{contents[i]}] vs expected: [{expected}])"

return assert_text_file_matches
