Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap nested_dask reduce #414

Merged
merged 12 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,3 +651,9 @@ def dropna(
ignore_index=ignore_index,
)
return catalog

def reduce(self, func, *args, meta=None, **kwargs) -> Catalog:
catalog = super().reduce(func, *args, meta=meta, **kwargs)
if self.margin is not None:
catalog.margin = self.margin.reduce(func, *args, meta=meta, **kwargs)
return catalog
82 changes: 82 additions & 0 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import dataclasses
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Tuple
Expand Down Expand Up @@ -29,6 +30,8 @@
from lsdb.catalog.dataset.dataset import Dataset
from lsdb.core.plotting.skymap import compute_skymap, perform_inner_skymap
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.dask.merge_catalog_functions import concat_metas
from lsdb.io.schema import get_arrow_schema
from lsdb.types import DaskDFPixelMap


Expand Down Expand Up @@ -510,3 +513,82 @@ def dropna(
hc_structure = copy.copy(self.hc_structure)
hc_structure.catalog_info.total_rows = None
return self.__class__(ndf, self._ddf_pixel_map, hc_structure)

def reduce(self, func, *args, meta=None, append_columns=False, **kwargs) -> Self:
"""
Takes a function and applies it to each top-level row of the Catalog.

docstring copied from nested-pandas

The user may specify which columns the function is applied to, with
columns from the 'base' layer being passsed to the function as
scalars and columns from the nested layers being passed as numpy arrays.

Parameters
----------
func : callable
Function to apply to each nested dataframe. The first arguments to `func` should be which
columns to apply the function to. See the Notes for recommendations
on writing func outputs.
args : positional arguments
Positional arguments to pass to the function, the first *args should be the names of the
columns to apply the function to.
meta : dataframe or series-like, optional
The dask meta of the output. If append_columns is True, the meta should specify just the
additional columns output by func.
append_columns : bool
If the output columns should be appended to the orignal dataframe.
kwargs : keyword arguments, optional
Keyword arguments to pass to the function.

Returns
-------
`HealpixDataset`
`HealpixDataset` with the results of the function applied to the columns of the frame.

Notes
-----
By default, `reduce` will produce a `NestedFrame` with enumerated
column names for each returned value of the function. For more useful
naming, it's recommended to have `func` return a dictionary where each
key is an output column of the dataframe returned by `reduce`.

Example User Function:

>>> def my_sum(col1, col2):
>>> '''reduce will return a NestedFrame with two columns'''
>>> return {"sum_col1": sum(col1), "sum_col2": sum(col2)}
>>>
>>> catalog.reduce(my_sum, 'sources.col1', 'sources.col2')

"""

if append_columns:
meta = concat_metas([self._ddf._meta.copy(), meta])

catalog_info = self.hc_structure.catalog_info

def reduce_part(df):
reduced_result = npd.NestedFrame(df).reduce(func, *args, **kwargs)
if append_columns:
if catalog_info.ra_column in reduced_result or catalog_info.dec_column in reduced_result:
raise ValueError("ra and dec columns can not be modified using reduce")
return npd.NestedFrame(pd.concat([df, reduced_result], axis=1))
return reduced_result

ndf = nd.NestedFrame.from_dask_dataframe(self._ddf.map_partitions(reduce_part, meta=meta))

hc_catalog = self.hc_structure
if not append_columns:
new_catalog_info = dataclasses.replace(
self.hc_structure.catalog_info,
ra_column="",
dec_column="",
)
hc_catalog = self.hc_structure.__class__(
new_catalog_info,
self.hc_structure.pixel_tree,
schema=get_arrow_schema(ndf),
moc=self.hc_structure.moc,
)
return self.__class__(ndf, self._ddf_pixel_map, hc_catalog)
17 changes: 17 additions & 0 deletions src/lsdb/dask/merge_catalog_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import numpy.typing as npt
import pandas as pd
from dask.dataframe.dispatch import make_meta
from dask.delayed import Delayed, delayed
from hipscat.catalog import PartitionInfo
from hipscat.pixel_math import HealpixPixel
Expand Down Expand Up @@ -308,6 +309,22 @@ def generate_meta_df_for_nested_tables(
return npd.NestedFrame(meta_df).add_nested(nested_catalog_meta, nested_column_name)


def concat_metas(metas: Sequence[npd.NestedFrame | dict]):
"""Concats the columns of a sequence of dask metas into a single NestedFrame meta

Args:
metas (Sequence[dict | DataFrame]): A collection of dask meta inputs

Returns:
(npd.NestedFrame) An empty NestedFrame with the columns of the input metas concatenated together in
the order of the input sequence.
"""
pandas_metas = []
for meta in metas:
pandas_metas.append(npd.NestedFrame(make_meta(meta)))
return npd.NestedFrame(pd.concat(pandas_metas, axis=1))


def get_partition_map_from_alignment_pixels(join_pixels: pd.DataFrame) -> DaskDFPixelMap:
"""Gets a dictionary mapping HEALPix pixel to index of pixel in the pixel_mapping of a `PixelAlignment`

Expand Down
98 changes: 98 additions & 0 deletions tests/lsdb/catalog/test_nested.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import nested_dask as nd
import nested_pandas as npd
import numpy as np
import pandas as pd
import pytest

from lsdb import Catalog

Expand All @@ -11,6 +13,7 @@ def test_dropna(small_sky_with_nested_sources):
assert isinstance(drop_na_cat, Catalog)
assert isinstance(drop_na_cat._ddf, nd.NestedFrame)
drop_na_compute = drop_na_cat.compute()
assert isinstance(drop_na_compute, npd.NestedFrame)
filtered_compute = filtered_cat.compute()
assert len(drop_na_compute) < len(filtered_compute)
pd.testing.assert_frame_equal(drop_na_compute, filtered_compute.dropna())
Expand All @@ -35,3 +38,98 @@ def add_na_values_nested(df):
pd.testing.assert_frame_equal(
drop_na_cat.compute(), filtered_cat._ddf.dropna(on_nested="sources").compute()
)


def test_reduce(small_sky_with_nested_sources):
def mean_mag(ra, dec, mag):
return {"ra": ra, "dec": dec, "mean_mag": np.mean(mag)}

reduced_cat = small_sky_with_nested_sources.reduce(
mean_mag, "ra", "dec", "sources.mag", meta={"ra": float, "dec": float, "mean_mag": float}
)

assert isinstance(reduced_cat, Catalog)
assert isinstance(reduced_cat._ddf, nd.NestedFrame)

reduced_cat_compute = reduced_cat.compute()
assert isinstance(reduced_cat_compute, npd.NestedFrame)

reduced_ddf = small_sky_with_nested_sources._ddf.reduce(
mean_mag, "ra", "dec", "sources.mag", meta={"ra": float, "dec": float, "mean_mag": float}
)

pd.testing.assert_frame_equal(reduced_cat_compute, reduced_ddf.compute())


def test_reduce_append_columns(small_sky_with_nested_sources):
def mean_mag(mag):
return {"mean_mag": np.mean(mag)}

reduced_cat = small_sky_with_nested_sources.reduce(
mean_mag, "sources.mag", meta={"mean_mag": float}, append_columns=True
)

assert isinstance(reduced_cat, Catalog)
assert isinstance(reduced_cat._ddf, nd.NestedFrame)

reduced_cat_compute = reduced_cat.compute()
assert isinstance(reduced_cat_compute, npd.NestedFrame)

reduced_ddf = small_sky_with_nested_sources._ddf.reduce(mean_mag, "sources.mag", meta={"mean_mag": float})

pd.testing.assert_series_equal(reduced_cat_compute["mean_mag"], reduced_ddf.compute()["mean_mag"])
pd.testing.assert_frame_equal(
reduced_cat_compute[small_sky_with_nested_sources.columns], small_sky_with_nested_sources.compute()
)


def test_reduce_no_return_column(small_sky_with_nested_sources):
def mean_mag(mag):
return np.mean(mag)

reduced_cat = small_sky_with_nested_sources.reduce(
mean_mag, "sources.mag", meta={0: float}, append_columns=True
)

assert isinstance(reduced_cat, Catalog)
assert isinstance(reduced_cat._ddf, nd.NestedFrame)

reduced_cat_compute = reduced_cat.compute()
assert isinstance(reduced_cat_compute, npd.NestedFrame)

reduced_ddf = small_sky_with_nested_sources._ddf.reduce(mean_mag, "sources.mag", meta={0: float})

pd.testing.assert_series_equal(reduced_cat_compute[0], reduced_ddf.compute()[0])
pd.testing.assert_frame_equal(
reduced_cat_compute[small_sky_with_nested_sources.columns], small_sky_with_nested_sources.compute()
)


def test_reduce_invalid_return_column(small_sky_with_nested_sources):
def mean_mag(mag):
return pd.DataFrame.from_dict({"mean_mag": [np.mean(mag)]})

reduced_cat = small_sky_with_nested_sources.reduce(
mean_mag, "sources.mag", meta={0: float}, append_columns=True
)

assert isinstance(reduced_cat, Catalog)
assert isinstance(reduced_cat._ddf, nd.NestedFrame)

with pytest.raises(ValueError):
reduced_cat.compute()


def test_reduce_append_columns_raises_error(small_sky_with_nested_sources):
def mean_mag(ra, dec, mag):
return {"ra": ra, "dec": dec, "mean_mag": np.mean(mag)}

with pytest.raises(ValueError):
small_sky_with_nested_sources.reduce(
mean_mag,
"ra",
"dec",
"sources.mag",
meta={"ra": float, "dec": float, "mean_mag": float},
append_columns=True,
).compute()