Skip to content

Commit

Permalink
Merge pull request #368 from astronomy-commons/nested
Browse files Browse the repository at this point in the history
Replace dask dataframe with nested-dask
  • Loading branch information
smcguire-cmu committed Aug 15, 2024
2 parents 4d4b1c8 + 18516c6 commit 72889fa
Show file tree
Hide file tree
Showing 38 changed files with 511 additions and 176 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ dependencies = [
# dask distributed eases the creation of parallel dask clients.
# dask diagnostics is required to spin up the dashboard for profiling.
"dask[complete]",
"deprecated",
"hipscat>=0.3.8",
"lsst-sphgeom", # To handle spherical sky polygons
"nested-dask",
"nested-pandas",
"pyarrow",
"deprecated",
"scipy", # kdtree
"lsst-sphgeom", # To handle spherical sky polygons
]

[project.urls]
Expand Down
4 changes: 2 additions & 2 deletions src/lsdb/catalog/association_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dask.dataframe as dd
import hipscat as hc
import nested_dask as nd

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.types import DaskDFPixelMap
Expand All @@ -18,7 +18,7 @@ class AssociationCatalog(HealpixDataset):

def __init__(
self,
ddf: dd.DataFrame,
ddf: nd.NestedFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: hc.catalog.AssociationCatalog,
):
Expand Down
82 changes: 72 additions & 10 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import dataclasses
from typing import List, Tuple, Type

import dask.dataframe as dd
import hipscat as hc
import nested_dask as nd
import nested_pandas as npd
import pandas as pd
from hipscat.catalog.index.index_catalog import IndexCatalog as HCIndexCatalog
from hipscat.pixel_math.polygon_filter import SphericalCoordinates
Expand All @@ -18,7 +19,11 @@
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.core.search.pixel_search import PixelSearch
from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data
from lsdb.dask.join_catalog_data import join_catalog_data_on, join_catalog_data_through
from lsdb.dask.join_catalog_data import (
join_catalog_data_nested,
join_catalog_data_on,
join_catalog_data_through,
)
from lsdb.dask.partition_indexer import PartitionIndexer
from lsdb.io.schema import get_arrow_schema
from lsdb.types import DaskDFPixelMap
Expand All @@ -38,7 +43,7 @@ class Catalog(HealpixDataset):

def __init__(
self,
ddf: dd.DataFrame,
ddf: nd.NestedFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: hc.catalog.Catalog,
margin: MarginCatalog | None = None,
Expand All @@ -61,7 +66,7 @@ def partitions(self):
"""Returns the partitions of the catalog"""
return PartitionIndexer(self)

def head(self, n: int = 5) -> pd.DataFrame:
def head(self, n: int = 5) -> npd.NestedFrame:
"""Returns a few rows of data for previewing purposes.
Args:
Expand All @@ -80,7 +85,7 @@ def head(self, n: int = 5) -> pd.DataFrame:
dfs.append(partition_head)
remaining_rows -= len(partition_head)
if len(dfs) > 0:
return pd.concat(dfs)
return npd.NestedFrame(pd.concat(dfs))
return self._ddf._meta

def query(self, expr: str) -> Catalog:
Expand Down Expand Up @@ -154,8 +159,8 @@ def crossmatch(
The class will have been initialized with the following parameters, which the
crossmatch function should use:
- left: pd.DataFrame,
- right: pd.DataFrame,
- left: npd.NestedFrame,
- right: npd.NestedFrame,
- left_order: int,
- left_pixel: int,
- right_order: int,
Expand Down Expand Up @@ -315,9 +320,9 @@ def search(self, search: AbstractSearch):
A new Catalog containing the points filtered to those matching the search parameters.
"""
filtered_hc_structure = search.filter_hc_catalog(self.hc_structure)
ddf_partition_map, search_ddf = self._perform_search(filtered_hc_structure, search)
ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search)
margin = self.margin.search(search) if self.margin is not None else None
return Catalog(search_ddf, ddf_partition_map, filtered_hc_structure, margin=margin)
return Catalog(search_ndf, ddf_partition_map, filtered_hc_structure, margin=margin)

def merge(
self,
Expand All @@ -329,7 +334,7 @@ def merge(
left_index: bool = False,
right_index: bool = False,
suffixes: Tuple[str, str] | None = None,
) -> dd.DataFrame:
) -> nd.NestedFrame:
"""Performs the merge of two catalog Dataframes
More information about pandas merge is available
Expand Down Expand Up @@ -444,3 +449,60 @@ def join(
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf))
return Catalog(ddf, ddf_map, hc_catalog)

def join_nested(
self,
other: Catalog,
left_on: str | None = None,
right_on: str | None = None,
nested_column_name: str | None = None,
output_catalog_name: str | None = None,
) -> Catalog:
"""Perform a spatial join to another catalog by adding the other catalog as a nested column
Joins two catalogs together on a shared column value, merging rows where they match.
The result is added as a nested dataframe column using
`nested-dask <https://github.com/lincc-frameworks/nested-dask>`__, where the right catalog's columns
are encoded within a column in the resulting dataframe. For more information, view the
`nested-dask documentation <https://nested-dask.readthedocs.io/en/latest/>`__.
The operation only joins data from matching partitions and their margin caches, and does not join rows
that have a matching column value but are in separate partitions in the sky. For a more general join,
see the `merge` function.
Args:
other (Catalog): the right catalog to join to
left_on (str): the name of the column in the left catalog to join on
right_on (str): the name of the column in the right catalog to join on
nested_column_name (str): the name of the nested column in the resulting dataframe storing the
joined columns in the right catalog. (Default: name of right catalog)
output_catalog_name (str): The name of the resulting catalog to be stored in metadata
Returns:
A new catalog with the columns from each of the input catalogs with their respective suffixes
added, and the rows merged on the specified columns.
"""

if left_on is None or right_on is None:
raise ValueError("Both of left_on and right_on must be set")

if left_on not in self._ddf.columns:
raise ValueError("left_on must be a column in the left catalog")

if right_on not in other._ddf.columns:
raise ValueError("right_on must be a column in the right catalog")

ddf, ddf_map, alignment = join_catalog_data_nested(
self, other, left_on, right_on, nested_column_name=nested_column_name
)

if output_catalog_name is None:
output_catalog_name = self.hc_structure.catalog_info.catalog_name

new_catalog_info = dataclasses.replace(
self.hc_structure.catalog_info,
catalog_name=output_catalog_name,
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
return Catalog(ddf, ddf_map, hc_catalog)
8 changes: 4 additions & 4 deletions src/lsdb/catalog/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List

import dask.dataframe as dd
import hipscat as hc
import pandas as pd
import nested_dask as nd
import nested_pandas as npd
from dask.delayed import Delayed


Expand All @@ -11,7 +11,7 @@ class Dataset:

def __init__(
self,
ddf: dd.DataFrame,
ddf: nd.NestedFrame,
hc_structure: hc.catalog.Dataset,
):
"""Initialise a Catalog object.
Expand All @@ -34,7 +34,7 @@ def _repr_html_(self):
data = self._ddf._repr_data().to_html(max_rows=5, show_dimensions=False, notebook=True)
return f"<div><strong>lsdb Catalog {self.name}:</strong></div>{data}"

def compute(self) -> pd.DataFrame:
def compute(self) -> npd.NestedFrame:
"""Compute dask distributed dataframe to pandas dataframe"""
return self._ddf.compute()

Expand Down
50 changes: 27 additions & 23 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import dask.dataframe as dd
import healpy as hp
import hipscat as hc
import nested_dask as nd
import nested_pandas as npd
import numpy as np
import pandas as pd
from dask.delayed import Delayed, delayed
Expand Down Expand Up @@ -38,7 +40,7 @@ class HealpixDataset(Dataset):

def __init__(
self,
ddf: dd.DataFrame,
ddf: nd.NestedFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: HCHealpixDataset,
):
Expand All @@ -57,7 +59,7 @@ def __init__(

def __getitem__(self, item):
result = self._ddf.__getitem__(item)
if isinstance(result, dd.DataFrame):
if isinstance(result, nd.NestedFrame):
return self.__class__(result, self._ddf_pixel_map, self.hc_structure)
return result

Expand All @@ -79,7 +81,7 @@ def get_ordered_healpix_pixels(self) -> List[HealpixPixel]:
pixels = self.get_healpix_pixels()
return np.array(pixels)[get_pixel_argsort(pixels)]

def get_partition(self, order: int, pixel: int) -> dd.DataFrame:
def get_partition(self, order: int, pixel: int) -> nd.NestedFrame:
"""Get the dask partition for a given HEALPix pixel
Args:
Expand Down Expand Up @@ -124,14 +126,14 @@ def query(self, expr: str) -> Self:
A catalog that contains the data from the original catalog that complies
with the query expression
"""
ddf = self._ddf.query(expr)
return self.__class__(ddf, self._ddf_pixel_map, self.hc_structure)
ndf = self._ddf.query(expr)
return self.__class__(ndf, self._ddf_pixel_map, self.hc_structure)

def _perform_search(
self,
metadata: hc.catalog.Catalog | hc.catalog.MarginCatalog,
search: AbstractSearch,
) -> Tuple[DaskDFPixelMap, dd.DataFrame]:
) -> Tuple[DaskDFPixelMap, nd.NestedFrame]:
"""Performs a search on the catalog from a list of pixels to search in
Args:
Expand All @@ -146,7 +148,7 @@ def _perform_search(
"""
filtered_pixels = metadata.get_healpix_pixels()
if len(filtered_pixels) == 0:
return {}, dd.from_pandas(self._ddf._meta)
return {}, nd.NestedFrame.from_pandas(self._ddf._meta)
target_partitions_indices = [self._ddf_pixel_map[pixel] for pixel in filtered_pixels]
filtered_partitions_ddf = self._ddf.partitions[target_partitions_indices]
if search.fine:
Expand All @@ -161,7 +163,7 @@ def _perform_search(

def map_partitions(
self,
func: Callable[..., pd.DataFrame],
func: Callable[..., npd.NestedFrame],
*args,
meta: pd.DataFrame | pd.Series | Dict | Iterable | Tuple | None = None,
include_pixel: bool = False,
Expand All @@ -173,11 +175,11 @@ def map_partitions(
Args:
func (Callable): The function applied to each partition, which will be called with:
`func(partition: pd.DataFrame, *args, **kwargs)` with the additional args and kwargs passed to
the `map_partitions` function. If the `include_pixel` parameter is set, the function will be
called with the `healpix_pixel` as the second positional argument set to the healpix pixel
`func(partition: npd.NestedFrame, *args, **kwargs)` with the additional args and kwargs passed
to the `map_partitions` function. If the `include_pixel` parameter is set, the function will
be called with the `healpix_pixel` as the second positional argument set to the healpix pixel
of the partition as
`func(partition: pd.DataFrame, healpix_pixel: HealpixPixel, *args, **kwargs)`
`func(partition: npd.NestedFrame, healpix_pixel: HealpixPixel, *args, **kwargs)`
*args: Additional positional arguments to call `func` with.
meta (pd.DataFrame | pd.Series | Dict | Iterable | Tuple | None): An empty pandas DataFrame that
has columns matching the output of the function applied to a partition. Other types are
Expand Down Expand Up @@ -222,8 +224,10 @@ def apply_func(df, *args, partition_info=None, **kwargs):
else:
output_ddf = self._ddf.map_partitions(func, *args, meta=meta, **kwargs)

if isinstance(output_ddf, dd.DataFrame):
return self.__class__(output_ddf, self._ddf_pixel_map, self.hc_structure)
if isinstance(output_ddf, nd.NestedFrame) | isinstance(output_ddf, dd.DataFrame):
return self.__class__(
nd.NestedFrame.from_dask_dataframe(output_ddf), self._ddf_pixel_map, self.hc_structure
)
warnings.warn(
"output of the function must be a DataFrame to generate an LSDB `Catalog`. `map_partitions` "
"will return a dask object instead of a Catalog.",
Expand All @@ -247,7 +251,7 @@ def prune_empty_partitions(self, persist: bool = False) -> Self:
search_ddf = (
self._ddf.partitions[non_empty_partitions]
if len(non_empty_partitions) > 0
else dd.from_pandas(self._ddf._meta, npartitions=1)
else nd.NestedFrame.from_pandas(self._ddf._meta, npartitions=1)
)
ddf_partition_map = {pixel: i for i, pixel in enumerate(non_empty_pixels)}
filtered_hc_structure = self.hc_structure.filter_from_pixel_list(non_empty_pixels)
Expand Down Expand Up @@ -276,15 +280,15 @@ def _get_non_empty_partitions(self) -> Tuple[List[HealpixPixel], np.ndarray]:

def skymap_data(
self,
func: Callable[[pd.DataFrame, HealpixPixel], Any],
func: Callable[[npd.NestedFrame, HealpixPixel], Any],
order: int | None = None,
default_value: Any = 0.0,
**kwargs,
) -> Dict[HealpixPixel, Delayed]:
"""Perform a function on each partition of the catalog, returning a dict of values for each pixel.
Args:
func (Callable[[pd.DataFrame, HealpixPixel], Any]): A function that takes a pandas
func (Callable[[npd.NestedFrame, HealpixPixel], Any]): A function that takes a pandas
DataFrame with the data in a partition, the HealpixPixel of the partition, and any other
keyword arguments and returns an aggregated value
order (int | None): The HEALPix order to compute the skymap at. If None (default),
Expand Down Expand Up @@ -321,7 +325,7 @@ def skymap_data(

def skymap_histogram(
self,
func: Callable[[pd.DataFrame, HealpixPixel], Any],
func: Callable[[npd.NestedFrame, HealpixPixel], Any],
order: int | None = None,
default_value: Any = 0.0,
**kwargs,
Expand All @@ -330,8 +334,8 @@ def skymap_histogram(
a given order
Args:
func (Callable[[pd.DataFrame], HealpixPixel, Any]): A function that takes a pandas DataFrame and
the HealpixPixel the partition is from and returns a value
func (Callable[[npd.NestedFrame, HealpixPixel], Any]): A function that takes a pandas DataFrame
and the HealpixPixel the partition is from and returns a value
order (int | None): The HEALPix order to compute the skymap at. If None (default),
will compute for each partition in the catalog at their own orders. If a value
other than None, each partition will be grouped by pixel number at the order
Expand All @@ -357,7 +361,7 @@ def skymap_histogram(

def skymap(
self,
func: Callable[[pd.DataFrame, HealpixPixel], Any],
func: Callable[[npd.NestedFrame, HealpixPixel], Any],
order: int | None = None,
default_value: Any = hp.pixelfunc.UNSEEN,
projection="moll",
Expand All @@ -367,8 +371,8 @@ def skymap(
"""Plot a skymap of an aggregate function applied over each partition
Args:
func (Callable[[pd.DataFrame], HealpixPixel, Any]): A function that takes a pandas DataFrame and
the HealpixPixel the partition is from and returns a value
func (Callable[[npd.NestedFrame, HealpixPixel], Any]): A function that takes a pandas DataFrame
and the HealpixPixel the partition is from and returns a value
order (int | None): The HEALPix order to compute the skymap at. If None (default),
will compute for each partition in the catalog at their own orders. If a value
other than None, each partition will be grouped by pixel number at the order
Expand Down
8 changes: 4 additions & 4 deletions src/lsdb/catalog/margin_catalog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dask.dataframe as dd
import hipscat as hc
import nested_dask as nd

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.core.search.abstract_search import AbstractSearch
Expand All @@ -19,7 +19,7 @@ class MarginCatalog(HealpixDataset):

def __init__(
self,
ddf: dd.DataFrame,
ddf: nd.NestedFrame,
ddf_pixel_map: DaskDFPixelMap,
hc_structure: hc.catalog.MarginCatalog,
):
Expand All @@ -38,5 +38,5 @@ def search(self, search: AbstractSearch):
A new Catalog containing the points filtered to those matching the search parameters.
"""
filtered_hc_structure = search.filter_hc_catalog(self.hc_structure)
ddf_partition_map, search_ddf = self._perform_search(filtered_hc_structure, search)
return self.__class__(search_ddf, ddf_partition_map, filtered_hc_structure)
ddf_partition_map, search_ndf = self._perform_search(filtered_hc_structure, search)
return self.__class__(search_ndf, ddf_partition_map, filtered_hc_structure)
Loading

0 comments on commit 72889fa

Please sign in to comment.