Skip to content

Commit

Permalink
Merge pull request #364 from astronomy-commons/sean/dask-expressions
Browse files Browse the repository at this point in the history
Change to using dask expressions
  • Loading branch information
smcguire-cmu authored Jun 24, 2024
2 parents 79d4cb3 + 846bc3a commit 7c52201
Show file tree
Hide file tree
Showing 19 changed files with 103 additions and 97 deletions.
2 changes: 1 addition & 1 deletion src/lsdb/catalog/association_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class AssociationCatalog(HealpixDataset):

def __init__(
self,
ddf: dd.core.DataFrame,
ddf: dd.DataFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: hc.catalog.AssociationCatalog,
):
Expand Down
4 changes: 2 additions & 2 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Catalog(HealpixDataset):

def __init__(
self,
ddf: dd.core.DataFrame,
ddf: dd.DataFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: hc.catalog.Catalog,
margin: MarginCatalog | None = None,
Expand Down Expand Up @@ -328,7 +328,7 @@ def merge(
left_index: bool = False,
right_index: bool = False,
suffixes: Tuple[str, str] | None = None,
) -> dd.core.DataFrame:
) -> dd.DataFrame:
"""Performs the merge of two catalog Dataframes
More information about pandas merge is available
Expand Down
2 changes: 1 addition & 1 deletion src/lsdb/catalog/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Dataset:

def __init__(
self,
ddf: dd.core.DataFrame,
ddf: dd.DataFrame,
hc_structure: hc.catalog.Dataset,
):
"""Initialise a Catalog object.
Expand Down
73 changes: 32 additions & 41 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import Any, Callable, Dict, Iterable, List, Tuple, cast
from typing import Any, Callable, Dict, Iterable, List, Tuple

import dask
import dask.dataframe as dd
Expand All @@ -21,7 +21,6 @@
from lsdb.catalog.dataset.dataset import Dataset
from lsdb.core.plotting.skymap import compute_skymap, perform_inner_skymap
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.dask.divisions import get_pixels_divisions
from lsdb.types import DaskDFPixelMap


Expand All @@ -39,7 +38,7 @@ class HealpixDataset(Dataset):

def __init__(
self,
ddf: dd.core.DataFrame,
ddf: dd.DataFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: HCHealpixDataset,
):
Expand All @@ -58,7 +57,7 @@ def __init__(

def __getitem__(self, item):
result = self._ddf.__getitem__(item)
if isinstance(result, dd.core.DataFrame):
if isinstance(result, dd.DataFrame):
return self.__class__(result, self._ddf_pixel_map, self.hc_structure)
return result

Expand All @@ -80,7 +79,7 @@ def get_ordered_healpix_pixels(self) -> List[HealpixPixel]:
pixels = self.get_healpix_pixels()
return np.array(pixels)[get_pixel_argsort(pixels)]

def get_partition(self, order: int, pixel: int) -> dd.core.DataFrame:
def get_partition(self, order: int, pixel: int) -> dd.DataFrame:
"""Get the dask partition for a given HEALPix pixel
Args:
Expand Down Expand Up @@ -132,7 +131,7 @@ def _perform_search(
self,
metadata: hc.catalog.Catalog | hc.catalog.MarginCatalog,
search: AbstractSearch,
) -> Tuple[dict, dd.core.DataFrame]:
) -> Tuple[DaskDFPixelMap, dd.DataFrame]:
"""Performs a search on the catalog from a list of pixels to search in
Args:
Expand All @@ -145,33 +144,20 @@ def _perform_search(
A tuple containing a dictionary mapping pixel to partition index and a dask dataframe
containing the search results
"""
partitions = self._ddf.to_delayed()
filtered_pixels = metadata.get_healpix_pixels()
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = (
[search.search_points(partition, metadata.catalog_info) for partition in targeted_partitions]
if search.fine
else targeted_partitions
)
return self._construct_search_ddf(filtered_pixels, filtered_partitions)

def _construct_search_ddf(
self, filtered_pixels: List[HealpixPixel], filtered_partitions: List[Delayed]
) -> Tuple[dict, dd.core.DataFrame]:
"""Constructs a search catalog pixel map and respective Dask Dataframe
Args:
filtered_pixels (List[HealpixPixel]): The list of pixels in the search
filtered_partitions (List[Delayed]): The list of delayed partitions
Returns:
The catalog pixel map and the respective Dask DataFrame
"""
divisions = get_pixels_divisions(filtered_pixels)
search_ddf = dd.io.from_delayed(filtered_partitions, meta=self._ddf._meta, divisions=divisions)
search_ddf = cast(dd.core.DataFrame, search_ddf)
if len(filtered_pixels) == 0:
return {}, dd.from_pandas(self._ddf._meta)
target_partitions_indices = [self._ddf_pixel_map[pixel] for pixel in filtered_pixels]
filtered_partitions_ddf = self._ddf.partitions[target_partitions_indices]
if search.fine:
filtered_partitions_ddf = filtered_partitions_ddf.map_partitions(
search.search_points,
metadata.catalog_info,
meta=self._ddf._meta,
transform_divisions=False,
)
ddf_partition_map = {pixel: i for i, pixel in enumerate(filtered_pixels)}
return ddf_partition_map, search_ddf
return ddf_partition_map, filtered_partitions_ddf

def map_partitions(
self,
Expand All @@ -180,7 +166,7 @@ def map_partitions(
meta: pd.DataFrame | pd.Series | Dict | Iterable | Tuple | None = None,
include_pixel: bool = False,
**kwargs,
) -> Self | dd.core.Series:
) -> Self | dd.Series:
"""Applies a function to each partition in the catalog.
The ra and dec of each row is assumed to remain unchanged.
Expand Down Expand Up @@ -236,7 +222,7 @@ def apply_func(df, *args, partition_info=None, **kwargs):
else:
output_ddf = self._ddf.map_partitions(func, *args, meta=meta, **kwargs)

if isinstance(output_ddf, dd.core.DataFrame):
if isinstance(output_ddf, dd.DataFrame):
return self.__class__(output_ddf, self._ddf_pixel_map, self.hc_structure)
warnings.warn(
"output of the function must be a DataFrame to generate an LSDB `Catalog`. `map_partitions` "
Expand All @@ -258,30 +244,35 @@ def prune_empty_partitions(self, persist: bool = False) -> Self:
if persist:
self._ddf.persist()
non_empty_pixels, non_empty_partitions = self._get_non_empty_partitions()
ddf_partition_map, search_ddf = self._construct_search_ddf(non_empty_pixels, non_empty_partitions)
search_ddf = (
self._ddf.partitions[non_empty_partitions]
if len(non_empty_partitions) > 0
else dd.from_pandas(self._ddf._meta, npartitions=1)
)
ddf_partition_map = {pixel: i for i, pixel in enumerate(non_empty_pixels)}
filtered_hc_structure = self.hc_structure.filter_from_pixel_list(non_empty_pixels)
return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure)

def _get_non_empty_partitions(self) -> Tuple[List[HealpixPixel], List[Delayed]]:
def _get_non_empty_partitions(self) -> Tuple[List[HealpixPixel], np.ndarray]:
"""Determines which pixels and partitions of a catalog are not empty
Returns:
A tuple with the non-empty pixels and respective partitions
"""
partitions = self._ddf.to_delayed()

# Compute partition lengths (expensive operation)
partition_sizes = self._ddf.map_partitions(len).compute()
empty_partition_indices = np.argwhere(partition_sizes == 0).flatten()
non_empty_partition_indices = np.argwhere(partition_sizes > 0).flatten()

non_empty_indices_set = set(non_empty_partition_indices)

# Extract the non-empty pixels and respective partitions
non_empty_pixels, non_empty_partitions = [], []
non_empty_pixels = []
for pixel, partition_index in self._ddf_pixel_map.items():
if partition_index not in empty_partition_indices:
if partition_index in non_empty_indices_set:
non_empty_pixels.append(pixel)
non_empty_partitions.append(partitions[partition_index])

return non_empty_pixels, non_empty_partitions
return non_empty_pixels, non_empty_partition_indices

def skymap_data(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class MarginCatalog(HealpixDataset):

def __init__(
self,
ddf: dd.core.DataFrame,
ddf: dd.DataFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: hc.catalog.MarginCatalog,
):
Expand Down
2 changes: 0 additions & 2 deletions src/lsdb/core/search/box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import Tuple

import dask
import numpy as np
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
Expand Down Expand Up @@ -40,7 +39,6 @@ def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFr
return box_filter(frame, self.ra, self.dec, metadata)


@dask.delayed
def box_filter(
data_frame: pd.DataFrame,
ra: Tuple[float, float] | None,
Expand Down
2 changes: 0 additions & 2 deletions src/lsdb/core/search/cone_search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import dask
import pandas as pd
from astropy.coordinates import SkyCoord
from hipscat.catalog.catalog_info import CatalogInfo
Expand Down Expand Up @@ -32,7 +31,6 @@ def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFr
return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, metadata)


@dask.delayed
def cone_filter(data_frame: pd.DataFrame, ra, dec, radius_arcsec, metadata: CatalogInfo):
"""Filters a dataframe to only include points within the specified cone
Expand Down
2 changes: 0 additions & 2 deletions src/lsdb/core/search/polygon_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Tuple

import dask
import healpy as hp
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -35,7 +34,6 @@ def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFr
return polygon_filter(frame, self.polygon, metadata)


@dask.delayed
def polygon_filter(data_frame: pd.DataFrame, polygon: ConvexPolygon, metadata: CatalogInfo):
"""Filters a dataframe to only include points within the specified polygon.
Expand Down
2 changes: 1 addition & 1 deletion src/lsdb/dask/crossmatch_catalog_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def crossmatch_catalog_data(
Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm
) = BuiltInCrossmatchAlgorithm.KD_TREE,
**kwargs,
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Cross-matches the data from two catalogs
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/lsdb/dask/join_catalog_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def perform_join_through(

def join_catalog_data_on(
left: Catalog, right: Catalog, left_on: str, right_on: str, suffixes: Tuple[str, str]
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Joins two catalogs spatially on a specified column
Args:
Expand Down Expand Up @@ -218,7 +218,7 @@ def join_catalog_data_on(

def join_catalog_data_through(
left: Catalog, right: Catalog, association: AssociationCatalog, suffixes: Tuple[str, str]
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Joins two catalogs with an association table
Args:
Expand Down
10 changes: 5 additions & 5 deletions src/lsdb/dask/merge_catalog_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from dask.delayed import Delayed
from dask.delayed import Delayed, delayed
from hipscat.catalog import PartitionInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN, healpix_to_hipscat_id
Expand Down Expand Up @@ -177,7 +177,7 @@ def filter_by_hipscat_index_to_pixel(dataframe: pd.DataFrame, order: int, pixel:

def construct_catalog_args(
partitions: List[Delayed], meta_df: pd.DataFrame, alignment: PixelAlignment
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Constructs the arguments needed to create a catalog from a list of delayed partitions
Args:
Expand All @@ -191,11 +191,11 @@ def construct_catalog_args(
"""
# generate dask df partition map from alignment
partition_map = get_partition_map_from_alignment_pixels(alignment.pixel_mapping)

# create dask df from delayed partitions
divisions = get_pixels_divisions(list(partition_map.keys()))
ddf = dd.io.from_delayed(partitions, meta=meta_df, divisions=divisions)
ddf = cast(dd.core.DataFrame, ddf)
partitions = partitions if len(partitions) > 0 else [delayed(pd.DataFrame([]))]
ddf = dd.from_delayed(partitions, meta=meta_df, divisions=divisions)
ddf = cast(dd.DataFrame, ddf)
return ddf, partition_map, alignment


Expand Down
2 changes: 1 addition & 1 deletion src/lsdb/loaders/dataframe/dataframe_catalog_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _compute_pixel_map(self) -> Dict[HealpixPixel, HealpixInfo]:

def _generate_dask_df_and_map(
self, pixel_map: Dict[HealpixPixel, HealpixInfo]
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, int]:
) -> Tuple[dd.DataFrame, DaskDFPixelMap, int]:
"""Load Dask DataFrame from HEALPix pixel Dataframes and
generate a mapping of HEALPix pixels to HEALPix Dataframes
Expand Down
6 changes: 3 additions & 3 deletions src/lsdb/loaders/dataframe/from_dataframe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def _generate_dask_dataframe(
pixel_dfs: List[pd.DataFrame], pixels: List[HealpixPixel], use_pyarrow_types: bool = True
) -> Tuple[dd.core.DataFrame, int]:
) -> Tuple[dd.DataFrame, int]:
"""Create the Dask Dataframe from the list of HEALPix pixel Dataframes
Args:
Expand All @@ -29,8 +29,8 @@ def _generate_dask_dataframe(
schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else []
delayed_dfs = [delayed(df) for df in pixel_dfs]
divisions = get_pixels_divisions(pixels)
ddf = dd.io.from_delayed(delayed_dfs, meta=schema, divisions=divisions)
ddf = ddf if isinstance(ddf, dd.core.DataFrame) else ddf.to_frame()
ddf = dd.from_delayed(delayed_dfs, meta=schema, divisions=divisions)
ddf = ddf if isinstance(ddf, dd.DataFrame) else ddf.to_frame()
return ddf, len(ddf)


Expand Down
Loading

0 comments on commit 7c52201

Please sign in to comment.