Skip to content

Commit

Permalink
Allow path-like objects in the margin cache argument (#353)
Browse files Browse the repository at this point in the history
* Allow for Path objects in margin_cache argument

* Add case in unit test
  • Loading branch information
camposandro committed Jun 10, 2024
1 parent 2e33436 commit 113bdd3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
5 changes: 2 additions & 3 deletions src/lsdb/loaders/hipscat/hipscat_loading_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable, List

import pandas as pd
from hipscat.io.file_io import FilePointer
from pandas.io._util import _arrow_dtype_mapping

from lsdb.catalog.margin_catalog import MarginCatalog
Expand All @@ -23,7 +24,7 @@ class HipscatLoadingConfig:
columns: List[str] | None = None
"""Columns to load from the catalog. If not specified, all columns are loaded"""

margin_cache: MarginCatalog | str | None = None
margin_cache: MarginCatalog | FilePointer | None = None
"""Margin cache for the catalog. It can be provided as a path for the margin on disk,
or as a margin object instance. By default, it is None."""

Expand All @@ -35,8 +36,6 @@ class HipscatLoadingConfig:
"""Extra kwargs for the pandas parquet file reader"""

def __post_init__(self):
if self.margin_cache is not None and not isinstance(self.margin_cache, (MarginCatalog, str)):
raise ValueError("`margin_cache` must be of type 'MarginCatalog' or 'str'")
if self.dtype_backend not in ["pyarrow", "numpy_nullable", None]:
raise ValueError("The data type backend must be either 'pyarrow' or 'numpy_nullable'")

Expand Down
11 changes: 6 additions & 5 deletions src/lsdb/loaders/hipscat/read_hipscat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import hipscat as hc
from hipscat.catalog import CatalogType
from hipscat.catalog.dataset import BaseCatalogInfo
from hipscat.io.file_io import FilePointer

from lsdb.catalog.association_catalog import AssociationCatalog
from lsdb.catalog.catalog import Catalog
Expand All @@ -26,11 +27,11 @@

# pylint: disable=unused-argument
def read_hipscat(
path: str,
path: FilePointer,
catalog_type: Type[CatalogTypeVar] | None = None,
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | str | None = None,
margin_cache: MarginCatalog | FilePointer | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
Expand All @@ -50,16 +51,16 @@ def read_hipscat(
)
Args:
path (str): The path that locates the root of the HiPSCat catalog
path (FilePointer): The path that locates the root of the HiPSCat catalog
catalog_type (Type[Dataset]): Default `None`. By default, the type of the catalog is loaded
from the catalog info and the corresponding object type is returned. Python's type hints
cannot allow a return type specified by a loaded value, so to use the correct return
type for type checking, the type of the catalog can be specified here. Use by specifying
the lsdb class for that catalog.
search_filter (Type[AbstractSearch]): Default `None`. The filter method to be applied.
columns (List[str]): Default `None`. The set of columns to filter the catalog on.
margin_cache (MarginCatalog | str): The margin cache for the main catalog, provided as a path
on disk or as an instance of the MarginCatalog object. Defaults to None.
margin_cache (MarginCatalog | FilePointer): The margin cache for the main catalog, provided
as a path on disk or as an instance of the MarginCatalog object. Defaults to None.
dtype_backend (str): Backend data type to apply to the catalog.
Defaults to "pyarrow". If None, no type conversion is performed.
storage_options (dict): Dictionary that contains abstract filesystem credentials
Expand Down
8 changes: 5 additions & 3 deletions src/lsdb/loaders/hipscat/read_hipscat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@ from __future__ import annotations

from typing import List, Type, overload

from hipscat.io.file_io import FilePointer

from lsdb.catalog.dataset.dataset import Dataset
from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.loaders.hipscat.abstract_catalog_loader import CatalogTypeVar

@overload
def read_hipscat(
path: str,
path: FilePointer,
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | str | None = None,
margin_cache: MarginCatalog | FilePointer | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
Expand All @@ -36,7 +38,7 @@ def read_hipscat(
catalog_type: Type[CatalogTypeVar],
search_filter: AbstractSearch | None = None,
columns: List[str] | None = None,
margin_cache: MarginCatalog | str | None = None,
margin_cache: MarginCatalog | FilePointer | None = None,
dtype_backend: str | None = "pyarrow",
storage_options: dict | None = None,
**kwargs,
Expand Down
10 changes: 5 additions & 5 deletions tests/lsdb/loaders/hipscat/test_read_hipscat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import hipscat as hc
import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -82,6 +84,9 @@ def test_catalog_with_margin(
# Provide the margin cache catalog path
catalog_2 = lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=small_sky_xmatch_margin_dir)
assert isinstance(catalog_2, lsdb.Catalog)
# Which can also be provided with a Path object
catalog_3 = lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=Path(small_sky_xmatch_margin_dir))
assert isinstance(catalog_3, lsdb.Catalog)
# The catalogs obtained are identical
assert catalog.margin.hc_structure.catalog_info == catalog_2.margin.hc_structure.catalog_info
assert catalog.margin.get_healpix_pixels() == catalog_2.margin.get_healpix_pixels()
Expand All @@ -94,11 +99,6 @@ def test_catalog_without_margin_is_none(small_sky_xmatch_dir):
assert catalog.margin is None


def test_catalog_with_wrong_margin_args(small_sky_xmatch_dir):
with pytest.raises(ValueError, match="must be of type"):
lsdb.read_hipscat(small_sky_xmatch_dir, margin_cache=1)


def test_read_hipscat_subset_with_cone_search(small_sky_order1_dir, small_sky_order1_catalog):
cone_search = ConeSearch(ra=0, dec=-80, radius_arcsec=20 * 3600)
# Filtering using catalog's cone_search
Expand Down

0 comments on commit 113bdd3

Please sign in to comment.