Skip to content

Commit

Permalink
Merge branch 'main' into issue/308/fine-filtering-on-read-hipcat
Browse files Browse the repository at this point in the history
  • Loading branch information
camposandro committed Jun 10, 2024
2 parents b48f941 + 2e33436 commit d0792af
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 163 deletions.
16 changes: 11 additions & 5 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def crossmatch(
Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm
) = BuiltInCrossmatchAlgorithm.KD_TREE,
output_catalog_name: str | None = None,
require_right_margin: bool = False,
**kwargs,
) -> Catalog:
"""Perform a cross-match between two catalogs
Expand Down Expand Up @@ -137,12 +138,12 @@ def crossmatch(
Custom function:
To specify a custom function, write a class that subclasses the
`AbstractCrossmatchAlgorithm` class, and overwrite the `crossmatch` function.
`AbstractCrossmatchAlgorithm` class, and overwrite the `perform_crossmatch` function.
The function should be able to perform a crossmatch on two pandas DataFrames
from a HEALPix pixel from each catalog. It should return a dataframe with the
combined set of columns from the input dataframes with the appropriate suffixes
and, eventually, a set of extra columns generated by the crossmatch algorithm.
from a partition from each catalog. It should return two 1d numpy arrays of equal lengths
with the indices of the matching rows from the left and right dataframes, and a dataframe
with any extra columns generated by the crossmatch algorithm, also with the same length.
These columns are specified in {AbstractCrossmatchAlgorithm.extra_columns}, with
their respective data types, by means of an empty pandas dataframe. As an example,
the KdTreeCrossmatch algorithm outputs a "_dist_arcsec" column with the distance between
Expand All @@ -166,10 +167,13 @@ def crossmatch(
You may add any additional keyword argument parameters to the crossmatch
function definition, and the user will be able to pass them in as kwargs in the
`Catalog.crossmatch` method.
`Catalog.crossmatch` method. Any additional keyword arguments must also be added to the
`CrossmatchAlgorithm.validate` classmethod by overwriting the method.
output_catalog_name (str): The name of the resulting catalog.
Default: {left_name}_x_{right_name}
require_right_margin (bool): If true, raises an error if the right margin is missing which could
lead to incomplete crossmatches. Default: False
Returns:
A Catalog with the data from the left and right catalogs merged with one row for each
Expand All @@ -183,6 +187,8 @@ def crossmatch(
suffixes = (f"_{self.name}", f"_{other.name}")
if len(suffixes) != 2:
raise ValueError("`suffixes` must be a tuple with two strings")
if other.margin is None and require_right_margin:
raise ValueError("Right catalog margin cache is required for cross-match.")
if output_catalog_name is None:
output_catalog_name = f"{self.name}_x_{other.name}"
ddf, ddf_map, alignment = crossmatch_catalog_data(
Expand Down
144 changes: 123 additions & 21 deletions src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,56 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Tuple
from abc import ABC
from typing import TYPE_CHECKING, Tuple

import numpy as np
import numpy.typing as npt
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.catalog.margin_cache import MarginCacheCatalogInfo
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN

if TYPE_CHECKING:
from lsdb.catalog import Catalog


# pylint: disable=too-many-instance-attributes, too-many-arguments
class AbstractCrossmatchAlgorithm(ABC):
"""Abstract class used to write a crossmatch algorithm"""
"""Abstract class used to write a crossmatch algorithm
To specify a custom function, write a class that subclasses the
`AbstractCrossmatchAlgorithm` class, and overwrite the `perform_crossmatch` function.
The function should be able to perform a crossmatch on two pandas DataFrames
from a partition from each catalog. It should return two 1d numpy arrays of equal lengths
with the indices of the matching rows from the left and right dataframes, and a dataframe
with any extra columns generated by the crossmatch algorithm, also with the same length.
These columns are specified in {AbstractCrossmatchAlgorithm.extra_columns}, with
their respective data types, by means of an empty pandas dataframe. As an example,
the KdTreeCrossmatch algorithm outputs a "_dist_arcsec" column with the distance between
data points. Its extra_columns attribute is specified as follows::
pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))})
The class will have been initialized with the following parameters, which the
crossmatch function should use:
- left: pd.DataFrame,
- right: pd.DataFrame,
- left_order: int,
- left_pixel: int,
- right_order: int,
- right_pixel: int,
- left_metadata: hc.catalog.Catalog,
- right_metadata: hc.catalog.Catalog,
- right_margin_hc_structure: hc.margin.MarginCatalog,
- suffixes: Tuple[str, str]
You may add any additional keyword argument parameters to the crossmatch
function definition, and the user will be able to pass them in as kwargs in the
`Catalog.crossmatch` method. Any additional keyword arguments must also be added to the
`CrossmatchAlgorithm.validate` classmethod by overwriting the method.
"""

extra_columns: pd.DataFrame | None = None
"""The metadata for the columns generated by the crossmatch algorithm"""
Expand Down Expand Up @@ -59,34 +98,61 @@ def __init__(
self.right_margin_catalog_info = right_margin_catalog_info
self.suffixes = suffixes

@abstractmethod
def crossmatch(self) -> pd.DataFrame:
def crossmatch(self, **kwargs) -> pd.DataFrame:
"""Perform a crossmatch"""
l_inds, r_inds, extra_cols = self.perform_crossmatch(**kwargs)
if not len(l_inds) == len(r_inds) == len(extra_cols):
raise ValueError(
"Crossmatch algorithm must return left and right indices and extra columns with same length"
)
return self._create_crossmatch_df(l_inds, r_inds, extra_cols)

def perform_crossmatch(self) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
"""Performs a crossmatch to get the indices of the matching rows and any extra columns
Any additional keyword arguments needed can be added to this method in the subclass, and the user
will be able to pass them through the `Catalog.crossmatch` method.
Returns:
A tuple of:
- a numpy array with the indices of the matching rows from the left table
- a numpy array with the indices of the matching rows from the right table
- a pandas dataframe with any additional columns generated by the algorithm
These all must have the same lengths
"""
raise NotImplementedError(
"CrossmatchAlgorithm must either implement `perform_crossmatch` or overwrite `crossmatch`"
)

# pylint: disable=unused-argument
def validate(self):
# pylint: disable=protected-access
@classmethod
def validate(cls, left: Catalog, right: Catalog):
"""Validate the metadata and arguments.
This method will be called **once**, after the algorithm object has
been initialized, during the lazy construction of the execution graph.
This can be used to catch simple errors without waiting for an
expensive ``.compute()`` call."""
expensive ``.compute()`` call.
This must accept any additional arguments the `crossmatch` method accepts.
"""
# Check that we have the appropriate columns in our dataset.
if self.left.index.name != HIPSCAT_ID_COLUMN:
if left._ddf.index.name != HIPSCAT_ID_COLUMN:
raise ValueError(f"index of left table must be {HIPSCAT_ID_COLUMN}")
if self.right.index.name != HIPSCAT_ID_COLUMN:
if right._ddf.index.name != HIPSCAT_ID_COLUMN:
raise ValueError(f"index of right table must be {HIPSCAT_ID_COLUMN}")
column_names = self.left.columns
if self.left_catalog_info.ra_column not in column_names:
raise ValueError(f"left table must have column {self.left_catalog_info.ra_column}")
if self.left_catalog_info.dec_column not in column_names:
raise ValueError(f"left table must have column {self.left_catalog_info.dec_column}")

column_names = self.right.columns
if self.right_catalog_info.ra_column not in column_names:
raise ValueError(f"right table must have column {self.right_catalog_info.ra_column}")
if self.right_catalog_info.dec_column not in column_names:
raise ValueError(f"right table must have column {self.right_catalog_info.dec_column}")
column_names = left._ddf.columns
if left.hc_structure.catalog_info.ra_column not in column_names:
raise ValueError(f"left table must have column {left.hc_structure.catalog_info.ra_column}")
if left.hc_structure.catalog_info.dec_column not in column_names:
raise ValueError(f"left table must have column {left.hc_structure.catalog_info.dec_column}")

column_names = right._ddf.columns
if right.hc_structure.catalog_info.ra_column not in column_names:
raise ValueError(f"right table must have column {right.hc_structure.catalog_info.ra_column}")
if right.hc_structure.catalog_info.dec_column not in column_names:
raise ValueError(f"right table must have column {right.hc_structure.catalog_info.dec_column}")

@staticmethod
def _rename_columns_with_suffix(dataframe, suffix):
Expand Down Expand Up @@ -116,3 +182,39 @@ def _append_extra_columns(cls, dataframe: pd.DataFrame, extra_columns: pd.DataFr
new_col = extra_columns[col]
new_col.index = dataframe.index
dataframe[col] = new_col

def _create_crossmatch_df(
self,
left_idx: npt.NDArray[np.int64],
right_idx: npt.NDArray[np.int64],
extra_cols: pd.DataFrame,
) -> pd.DataFrame:
"""Creates a df containing the crossmatch result from matching indices and additional columns
Args:
left_idx (np.ndarray): indices of the matching rows from the left table
right_idx (np.ndarray): indices of the matching rows from the right table
extra_cols (pd.DataFrame): dataframe containing additional columns from crossmatching
Returns:
A dataframe with the matching rows from the left and right table concatenated together, with the
additional columns added
"""
# rename columns so no same names during merging
self._rename_columns_with_suffix(self.left, self.suffixes[0])
self._rename_columns_with_suffix(self.right, self.suffixes[1])
# concat dataframes together
self.left.index.name = HIPSCAT_ID_COLUMN
left_join_part = self.left.iloc[left_idx].reset_index()
right_join_part = self.right.iloc[right_idx].reset_index(drop=True)
out = pd.concat(
[
left_join_part,
right_join_part,
],
axis=1,
)
out.set_index(HIPSCAT_ID_COLUMN, inplace=True)
extra_cols.index = out.index
self._append_extra_columns(out, extra_cols)
return out
25 changes: 17 additions & 8 deletions src/lsdb/core/crossmatch/bounded_kdtree_match.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Tuple

import numpy as np
import pandas as pd
import pyarrow as pa

from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch
from lsdb.core.crossmatch.kdtree_utils import _find_crossmatch_indices, _get_chord_distance

if TYPE_CHECKING:
from lsdb import Catalog


class BoundedKdTreeCrossmatch(KdTreeCrossmatch):
"""Nearest neighbor crossmatch using a distance range"""

@classmethod
def validate(
self,
cls,
left: Catalog,
right: Catalog,
n_neighbors: int = 1,
radius_arcsec: float = 1,
require_right_margin: bool = False,
min_radius_arcsec: float = 0,
):
super().validate(n_neighbors, radius_arcsec, require_right_margin)
super().validate(left, right, n_neighbors, radius_arcsec)
if min_radius_arcsec < 0:
raise ValueError("The minimum radius must be non-negative")
if radius_arcsec <= min_radius_arcsec:
raise ValueError("Cross match maximum radius must be greater than cross match minimum radius")

def crossmatch(
def perform_crossmatch(
self,
n_neighbors: int = 1,
radius_arcsec: float = 1,
# We need it here because the signature is shared with .validate()
require_right_margin: bool = False, # pylint: disable=unused-argument
min_radius_arcsec: float = 0,
) -> pd.DataFrame:
) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
"""Perform a cross-match between the data from two HEALPix pixels
Finds the n closest neighbors in the right catalog for each point in the left catalog that
Expand All @@ -57,4 +63,7 @@ def crossmatch(
left_xyz, right_xyz, n_neighbors=n_neighbors, min_distance=min_d_chord, max_distance=max_d_chord
)
arc_distances = np.degrees(2.0 * np.arcsin(0.5 * chord_distances)) * 3600
return self._create_crossmatch_df(left_idx, right_idx, arc_distances)
extra_columns = pd.DataFrame(
{"_dist_arcsec": pd.Series(arc_distances, dtype=pd.ArrowDtype(pa.float64()))}
)
return left_idx, right_idx, extra_columns
10 changes: 10 additions & 0 deletions src/lsdb/core/crossmatch/crossmatch_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from enum import Enum

from lsdb.core.crossmatch.bounded_kdtree_match import BoundedKdTreeCrossmatch
from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch


class BuiltInCrossmatchAlgorithm(str, Enum):
"""Cross-matching algorithms included in lsdb"""

KD_TREE = "kd_tree"
BOUNDED_KD_TREE = "bounded_kd_tree"


builtin_crossmatch_algorithms = {
BuiltInCrossmatchAlgorithm.KD_TREE: KdTreeCrossmatch,
BuiltInCrossmatchAlgorithm.BOUNDED_KD_TREE: BoundedKdTreeCrossmatch,
}
Loading

0 comments on commit d0792af

Please sign in to comment.