Index generation improvements (#248)
* Use from_map for index generation

* Pass client to index runner.

---------

Co-authored-by: Melissa DeLucchi <[email protected]>
delucchi-cmu and Melissa DeLucchi authored Mar 1, 2024
1 parent 6a0df97 commit ea60042
Showing 5 changed files with 61 additions and 66 deletions.
79 changes: 36 additions & 43 deletions src/hipscat_import/index/map_reduce.py
@@ -2,13 +2,31 @@

import dask.dataframe as dd
import numpy as np
from dask.distributed import progress, wait
import pandas as pd
from hipscat.io import paths
from hipscat.io.file_io import file_io
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN


def create_index(args):
def read_leaf_file(input_file, include_columns, include_hipscat_index, drop_duplicates):
"""Mapping function called once per input file.
Reads the leaf parquet file, and returns with appropriate columns and duplicates dropped."""
data = pd.read_parquet(
input_file,
columns=include_columns,
engine="pyarrow",
)

data = data.reset_index()
if not include_hipscat_index:
data = data.drop(columns=[HIPSCAT_ID_COLUMN])

if drop_duplicates:
data = data.drop_duplicates()
return data


def create_index(args, client):
"""Read primary column, indexing column, and other payload data,
and write to catalog directory."""
include_columns = [args.indexing_column]
@@ -19,15 +37,19 @@ def create_index(args):

index_dir = paths.append_paths_to_pointer(args.catalog_path, "index")

metadata_file = paths.get_parquet_metadata_pointer(args.input_catalog_path)

metadata = file_io.read_parquet_metadata(metadata_file)
data = dd.read_parquet(
path=args.input_catalog_path,
columns=include_columns,
engine="pyarrow",
dataset={"partitioning": {"flavor": "hive", "schema": metadata.schema.to_arrow_schema()}},
filesystem="arrow",
data = dd.from_map(
read_leaf_file,
[
paths.pixel_catalog_file(
catalog_base_dir=args.input_catalog.catalog_base_dir,
pixel_order=pixel.order,
pixel_number=pixel.pixel,
)
for pixel in args.input_catalog.get_healpix_pixels()
],
include_columns=include_columns,
include_hipscat_index=args.include_hipscat_index,
drop_duplicates=args.drop_duplicates,
)

if args.include_order_pixel:
@@ -36,49 +58,20 @@
data["Dir"] = data["Dir"].astype(np.uint64)
data["Npix"] = data["Npix"].astype(np.uint64)

# There are some silly dask things happening here:
# - Turn the existing index column into a regular column
# - If that had been the _hipscat_index, and we don't want it anymore, drop it
# - Create a new index, using our target indexing_column.
# Use division hints if provided.
data = data.reset_index()
if not args.include_hipscat_index:
data = data.drop(columns=[HIPSCAT_ID_COLUMN])

if args.division_hints is not None and len(args.division_hints) > 2:
data = data.set_index(args.indexing_column, divisions=args.division_hints)
else:
# Try to avoid this! It's expensive! See:
# https://docs.dask.org/en/latest/generated/dask.dataframe.DataFrame.set_index.html
data = data.set_index(args.indexing_column)

if args.drop_duplicates:
# More dask things:
# - Repartition the whole dataset to account for limited memory in
# pyarrow in the drop_duplicates implementation (
# "array cannot contain more than 2147483646 bytes")
# - Reset the index, so the indexing_column values can be considered
# when de-duping.
# - Drop duplicate rows
# - Set the index back to our indexing_column, but this time, the
# values are still sorted so it's cheaper.
data = (
data.repartition(partition_size=1_000_000_000)
.reset_index()
.drop_duplicates()
.set_index(args.indexing_column, sorted=True, partition_size=args.compute_partition_size)
)
else:
data = data.repartition(partition_size=args.compute_partition_size)
data = data.repartition(partition_size=args.compute_partition_size)

# Now just write it out to leaf parquet files!
result = data.to_parquet(
path=index_dir,
engine="pyarrow",
compute_kwargs={"partition_size": args.compute_partition_size},
)
if args.progress_bar: # pragma: no cover
progress(result)
else:
wait(result)
client.compute(result)
return len(data)
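
For context on the `dd.from_map` pattern adopted above: Dask calls the reader once per element of the input list, each call becomes one partition of the resulting DataFrame, and extra keyword arguments are forwarded to every call. A minimal, self-contained sketch of that pattern (the reader and its inputs are hypothetical stand-ins, not the project's `read_leaf_file` or real catalog paths):

```python
import dask.dataframe as dd
import pandas as pd


def make_partition(pixel_id, columns):
    # Hypothetical stand-in for a per-file parquet read, so the sketch runs
    # without any catalog on disk.
    return pd.DataFrame({col: [pixel_id, pixel_id] for col in columns})


# One partition per list element; keyword arguments are passed to every call.
ddf = dd.from_map(make_partition, [0, 1, 2], columns=["object_id"])
print(ddf.npartitions)  # 3
print(ddf.compute())
```

Pushing the column selection, `_hipscat_index` handling, and duplicate dropping into the per-file reader replaces the earlier whole-frame reset_index/repartition/drop_duplicates sequence; the divisions-aware `set_index` and the final `repartition` and `to_parquet` remain.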
4 changes: 2 additions & 2 deletions src/hipscat_import/index/run_index.py
@@ -8,13 +8,13 @@
from hipscat_import.pipeline_resume_plan import PipelineResumePlan


def run(args):
def run(args, client):
"""Run index creation pipeline."""
if not args:
raise TypeError("args is required and should be type IndexArguments")
if not isinstance(args, IndexArguments):
raise TypeError("args must be type IndexArguments")
rows_written = mr.create_index(args)
rows_written = mr.create_index(args, client)

# All done - write out the metadata
with tqdm(
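
The runner's argument validation is untouched; only the required `client` parameter is new. A small sketch of the contract exercised by the tests further down (module path taken from the file location; behaviour mirrors `test_empty_args` and `test_bad_args`):

```python
import pytest

import hipscat_import.index.run_index as runner

# Both failure modes predate this change: args must be a populated
# IndexArguments instance. The second argument is the Dask client that is
# now threaded through to create_index.
with pytest.raises(TypeError, match="IndexArguments"):
    runner.run(None, None)

with pytest.raises(TypeError, match="IndexArguments"):
    runner.run({"output_artifact_name": "bad_arg_type"}, None)
```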
2 changes: 1 addition & 1 deletion src/hipscat_import/pipeline.py
@@ -43,7 +43,7 @@ def pipeline_with_client(args: RuntimeArguments, client: Client):
if isinstance(args, ImportArguments):
catalog_runner.run(args, client)
elif isinstance(args, IndexArguments):
index_runner.run(args)
index_runner.run(args, client)
elif isinstance(args, MarginCacheArguments):
margin_runner.generate_margin_cache(args, client)
elif isinstance(args, SoapArguments):
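
With the index runner now accepting a client, an index build goes through the shared entry point the same way as the other stages. A hedged usage sketch; the `IndexArguments` import path and constructor arguments are assumed from the tests below and may not match the exact signature:

```python
from dask.distributed import Client

from hipscat_import.index.arguments import IndexArguments  # assumed module path
from hipscat_import.pipeline import pipeline_with_client

if __name__ == "__main__":
    with Client(n_workers=2, threads_per_worker=1) as client:
        args = IndexArguments(
            input_catalog_path="data/small_sky_object_catalog",  # hypothetical path
            indexing_column="id",
            output_path="data/output",                            # hypothetical path
            output_artifact_name="small_sky_object_index",
            progress_bar=False,
        )
        # Per the change above, this dispatches to index_runner.run(args, client).
        pipeline_with_client(args, client)
```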
28 changes: 14 additions & 14 deletions tests/hipscat_import/index/test_index_map_reduce.py
@@ -16,6 +16,7 @@ def test_create_index(
small_sky_object_catalog,
assert_parquet_file_index,
tmp_path,
dask_client,
):
"""Create an index for simple object catalog"""
args = IndexArguments(
@@ -26,7 +27,7 @@
overwrite=True,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_object_index", "index", "part.0.parquet")

@@ -43,7 +44,7 @@


@pytest.mark.dask
def test_create_index_no_hipscat_index(small_sky_object_catalog, tmp_path):
def test_create_index_no_hipscat_index(small_sky_object_catalog, tmp_path, dask_client):
"""Create an index for simple object catalog, without the _hipscat_index field."""
args = IndexArguments(
input_catalog_path=small_sky_object_catalog,
Expand All @@ -54,7 +55,7 @@ def test_create_index_no_hipscat_index(small_sky_object_catalog, tmp_path):
overwrite=True,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_object_index", "index", "part.0.parquet")

@@ -64,7 +65,7 @@ def test_create_index_no_hipscat_index(small_sky_object_catalog, tmp_path):


@pytest.mark.dask
def test_create_index_no_order_pixel(small_sky_object_catalog, tmp_path):
def test_create_index_no_order_pixel(small_sky_object_catalog, tmp_path, dask_client):
"""Create an index for simple object catalog, without the partitioning columns,
Norder, Dir, and Npix."""
args = IndexArguments(
@@ -76,7 +77,7 @@ def test_create_index_no_order_pixel(small_sky_object_catalog, tmp_path):
overwrite=True,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_object_index", "index", "part.0.parquet")

@@ -86,11 +87,7 @@


@pytest.mark.dask
def test_create_index_source(
small_sky_source_catalog,
assert_parquet_file_index,
tmp_path,
):
def test_create_index_source(small_sky_source_catalog, assert_parquet_file_index, tmp_path, dask_client):
"""Create simple index for the source table."""
args = IndexArguments(
input_catalog_path=small_sky_source_catalog,
@@ -100,7 +97,7 @@
overwrite=True,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_source_index", "index", "part.0.parquet")

@@ -122,6 +119,7 @@ def test_create_index_with_divisions(
small_sky_source_catalog,
assert_parquet_file_index,
tmp_path,
dask_client,
):
"""Create an index catalog for the large(r) source catalog, passing
some divisions hints. This should partition the final output according to
@@ -139,7 +137,7 @@
drop_duplicates=False,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_source_index", "index", "part.0.parquet")

@@ -161,6 +159,7 @@ def test_create_index_source_by_object(
small_sky_source_catalog,
assert_parquet_file_index,
tmp_path,
dask_client,
):
"""Create an index for the source table, using the source's object ID
as the indexing key."""
@@ -172,7 +171,7 @@
overwrite=True,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_source_index", "index", "part.0.parquet")

@@ -193,6 +192,7 @@ def test_create_index_extra_columns(
small_sky_source_catalog,
assert_parquet_file_index,
tmp_path,
dask_client,
):
"""Create an index with some additional payload columns."""
args = IndexArguments(
@@ -204,7 +204,7 @@
overwrite=True,
progress_bar=False,
)
mr.create_index(args)
mr.create_index(args, dask_client)

output_file = os.path.join(tmp_path, "small_sky_source_index", "index", "part.0.parquet")

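
The `@pytest.mark.dask` tests above now take a `dask_client` fixture instead of letting `create_index` find a scheduler on its own. The fixture itself lives in the project's conftest and is not part of this diff; a plausible minimal version, with the scope and worker counts as assumptions, would be:

```python
import pytest
from dask.distributed import Client


@pytest.fixture(scope="session", name="dask_client")
def dask_client_fixture():
    """Single local Dask client shared by the tests marked @pytest.mark.dask."""
    client = Client(n_workers=1, threads_per_worker=1)
    yield client
    client.close()
```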
14 changes: 8 additions & 6 deletions tests/hipscat_import/index/test_run_index.py
@@ -15,20 +15,21 @@
def test_empty_args():
"""Runner should fail with empty arguments"""
with pytest.raises(TypeError, match="IndexArguments"):
runner.run(None)
runner.run(None, None)


def test_bad_args():
"""Runner should fail with mis-typed arguments"""
args = {"output_artifact_name": "bad_arg_type"}
with pytest.raises(TypeError, match="IndexArguments"):
runner.run(args)
runner.run(args, None)


@pytest.mark.dask
def test_run_index(
small_sky_object_catalog,
tmp_path,
dask_client,
):
"""Test appropriate metadata is written"""

@@ -40,7 +41,7 @@ def test_run_index(
overwrite=True,
progress_bar=False,
)
runner.run(args)
runner.run(args, dask_client)

# Check that the catalog metadata file exists
catalog = Dataset.read_from_hipscat(args.catalog_path)
@@ -72,6 +73,7 @@ def test_run_index(
def test_run_index_on_source(
small_sky_source_catalog,
tmp_path,
dask_client,
):
"""Test appropriate metadata is written, when primary catalog covers multiple pixels."""

@@ -83,7 +85,7 @@
overwrite=True,
progress_bar=False,
)
runner.run(args)
runner.run(args, dask_client)

# Check that the catalog metadata file exists
catalog = Dataset.read_from_hipscat(args.catalog_path)
@@ -114,7 +116,7 @@ def test_run_index_on_source(
@pytest.mark.dask
def test_run_index_on_source_object_id(
small_sky_source_catalog,
dask_client, # pylint: disable=unused-argument
dask_client,
tmp_path,
assert_parquet_file_index,
):
Expand All @@ -129,7 +131,7 @@ def test_run_index_on_source_object_id(
include_hipscat_index=False,
progress_bar=False,
)
runner.run(args)
runner.run(args, dask_client)

# Check that the catalog metadata file exists
catalog = Dataset.read_from_hipscat(args.catalog_path)
